@@ -15,34 +15,36 @@ function Base.show(io::IO, mss::AbstractMultiStepScheme)
15
15
print (io, " MultiStepSchemes.$(string (nameof (typeof (mss)))[3 : end ]) " )
16
16
end
17
17
18
- alg_steps (:: Type{T} ) where {T <: AbstractMultiStepScheme } = alg_steps (T ())
18
+ newton_steps (:: Type{T} ) where {T <: AbstractMultiStepScheme } = newton_steps (T ())
19
19
20
20
struct __PotraPtak3 <: AbstractMultiStepScheme end
21
21
const PotraPtak3 = __PotraPtak3 ()
22
22
23
- alg_steps (:: __PotraPtak3 ) = 2
23
+ newton_steps (:: __PotraPtak3 ) = 2
24
24
nintermediates (:: __PotraPtak3 ) = 1
25
25
26
26
@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
27
27
jvp_autodiff = nothing
28
28
end
29
29
const SinghSharma4 = __SinghSharma4 ()
30
30
31
- alg_steps (:: __SinghSharma4 ) = 3
31
+ newton_steps (:: __SinghSharma4 ) = 4
32
+ nintermediates (:: __SinghSharma4 ) = 2
32
33
33
34
@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
34
35
jvp_autodiff = nothing
35
36
end
36
37
const SinghSharma5 = __SinghSharma5 ()
37
38
38
- alg_steps (:: __SinghSharma5 ) = 3
39
+ newton_steps (:: __SinghSharma5 ) = 4
40
+ nintermediates (:: __SinghSharma5 ) = 2
39
41
40
42
@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
41
43
jvp_autodiff = nothing
42
44
end
43
45
const SinghSharma7 = __SinghSharma7 ()
44
46
45
- alg_steps (:: __SinghSharma7 ) = 4
47
+ newton_steps (:: __SinghSharma7 ) = 6
46
48
47
49
@generated function display_name (alg:: T ) where {T <: AbstractMultiStepScheme }
48
50
res = Symbol (first (split (last (split (string (T), " ." )), " {" ; limit = 2 ))[3 : end ])
@@ -75,6 +77,8 @@ supports_trust_region(::GenericMultiStepDescent) = false
75
77
fus
76
78
internal_cache
77
79
internal_caches
80
+ extra
81
+ extras
78
82
scheme:: S
79
83
timer
80
84
nf:: Int
@@ -91,49 +95,37 @@ function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = ca
91
95
end
92
96
93
97
function __internal_multistep_caches (
94
- scheme:: MSS.__PotraPtak3 , alg:: GenericMultiStepDescent ,
95
- prob, args... ; shared:: Val{N} = Val (1 ), kwargs... ) where {N}
98
+ scheme:: Union{MSS.__PotraPtak3, MSS.__SinghSharma4, MSS.__SinghSharma5} ,
99
+ alg:: GenericMultiStepDescent , prob, args... ;
100
+ shared:: Val{N} = Val (1 ), kwargs... ) where {N}
96
101
internal_descent = NewtonDescent (; alg. linsolve, alg. precs)
97
- internal_cache = __internal_init (
102
+ return @shared_caches N __internal_init (
98
103
prob, internal_descent, args... ; kwargs... , shared = Val (2 ))
99
- internal_caches = N ≤ 1 ? nothing :
100
- map (2 : N) do i
101
- __internal_init (prob, internal_descent, args... ; kwargs... , shared = Val (2 ))
102
- end
103
- return internal_cache, internal_caches
104
104
end
105
105
106
+ __extras_cache (:: MSS.AbstractMultiStepScheme , args... ; kwargs... ) = nothing , nothing
107
+
106
108
function __internal_init (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
107
109
alg:: GenericMultiStepDescent , J, fu, u; shared:: Val{N} = Val (1 ),
108
110
pre_inverted:: Val{INV} = False, linsolve_kwargs = (;),
109
111
abstol = nothing , reltol = nothing , timer = get_timer_output (),
110
112
kwargs... ) where {INV, N}
111
- @bb δu = similar (u)
112
- δus = N ≤ 1 ? nothing : map (2 : N) do i
113
- @bb δu_ = similar (u)
114
- end
115
- fu_cache = ntuple (MSS. nintermediates (alg. scheme)) do i
113
+ δu, δus = @shared_caches N (@bb δu = similar (u))
114
+ fu_cache, fus_cache = @shared_caches N (ntuple (MSS. nintermediates (alg. scheme)) do i
116
115
@bb xx = similar (fu)
117
- end
118
- fus_cache = N ≤ 1 ? nothing : map (2 : N) do i
119
- ntuple (MSS. nintermediates (alg. scheme)) do j
120
- @bb xx = similar (fu)
121
- end
122
- end
123
- u_cache = ntuple (MSS. nintermediates (alg. scheme)) do i
116
+ end )
117
+ u_cache, us_cache = @shared_caches N (ntuple (MSS. nintermediates (alg. scheme)) do i
124
118
@bb xx = similar (u)
125
- end
126
- us_cache = N ≤ 1 ? nothing : map (2 : N) do i
127
- ntuple (MSS. nintermediates (alg. scheme)) do j
128
- @bb xx = similar (u)
129
- end
130
- end
119
+ end )
131
120
internal_cache, internal_caches = __internal_multistep_caches (
132
121
alg. scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
133
122
abstol, reltol, timer, kwargs... )
123
+ extra, extras = __extras_cache (
124
+ alg. scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
125
+ abstol, reltol, timer, kwargs... )
134
126
return GenericMultiStepDescentCache (
135
127
prob. f, prob. p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
136
- internal_cache, internal_caches, alg. scheme, timer, 0 )
128
+ internal_cache, internal_caches, extra, extras, alg. scheme, timer, 0 )
137
129
end
138
130
139
131
function __internal_solve! (cache:: GenericMultiStepDescentCache{MSS.__PotraPtak3, INV} , J,
0 commit comments