@@ -44,26 +44,56 @@ function Base.show(io::IO, alg::NonlinearSolvePolyAlgorithm{pType, N}) where {pT
44
44
end
45
45
end
46
46
47
- @concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N} < :
48
- AbstractNonlinearSolveCache{iip, false }
47
+ @concrete mutable struct NonlinearSolvePolyAlgorithmCache{iip, N, timeit } < :
48
+ AbstractNonlinearSolveCache{iip, timeit }
49
49
caches
50
50
alg
51
+ best:: Int
51
52
current:: Int
53
+ nsteps:: Int
54
+ total_time:: Float64
55
+ maxtime
56
+ retcode:: ReturnCode.T
57
+ force_stop:: Bool
58
+ maxiters:: Int
59
+ end
60
+
61
+ function Base. show (
62
+ io:: IO , cache:: NonlinearSolvePolyAlgorithmCache{pType, N} ) where {pType, N}
63
+ problem_kind = ifelse (pType == :NLS , " NonlinearProblem" , " NonlinearLeastSquaresProblem" )
64
+ println (io, " NonlinearSolvePolyAlgorithmCache for $(problem_kind) with $(N) algorithms" )
65
+ best_alg = ifelse (cache. best == - 1 , " nothing" , cache. best)
66
+ println (io, " Best algorithm: $(best_alg) " )
67
+ println (io, " Current algorithm: $(cache. current) " )
68
+ println (io, " nsteps: $(cache. nsteps) " )
69
+ println (io, " retcode: $(cache. retcode) " )
70
+ __show_cache (io, cache. caches[cache. current], 0 )
52
71
end
53
72
54
73
function reinit_cache! (cache:: NonlinearSolvePolyAlgorithmCache , args... ; kwargs... )
55
74
foreach (c -> reinit_cache! (c, args... ; kwargs... ), cache. caches)
56
75
cache. current = 1
76
+ cache. nsteps = 0
77
+ cache. total_time = 0.0
57
78
end
58
79
59
80
for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
60
81
algType = NonlinearSolvePolyAlgorithm{pType}
61
82
@eval begin
62
- function SciMLBase. __init (
63
- prob:: $probType , alg:: $algType{N} , args... ; kwargs... ) where {N}
64
- return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N} (
65
- map (solver -> SciMLBase. __init (prob, solver, args... ; kwargs... ), alg. algs),
66
- alg, 1 )
83
+ function SciMLBase. __init (prob:: $probType , alg:: $algType{N} , args... ;
84
+ maxtime = nothing , maxiters = 1000 , kwargs... ) where {N}
85
+ return NonlinearSolvePolyAlgorithmCache {isinplace(prob), N, maxtime !== nothing} (
86
+ map (solver -> SciMLBase. __init (prob, solver, args... ; maxtime, kwargs... ),
87
+ alg. algs),
88
+ alg,
89
+ - 1 ,
90
+ 1 ,
91
+ 0 ,
92
+ 0.0 ,
93
+ maxtime,
94
+ ReturnCode. Default,
95
+ false ,
96
+ maxiters)
67
97
end
68
98
end
69
99
end
89
119
fu = get_fu ($ (cache_syms[i]))
90
120
return SciMLBase. build_solution (
91
121
$ (sol_syms[i]). prob, cache. alg, u, fu;
92
- retcode = ReturnCode . Success , stats,
122
+ retcode = $ (sol_syms[i]) . retcode , stats,
93
123
original = $ (sol_syms[i]), trace = $ (sol_syms[i]). trace)
94
124
end
95
125
cache. current = $ (i + 1 )
@@ -103,12 +133,11 @@ end
103
133
end
104
134
push! (calls,
105
135
quote
106
- retcode = ReturnCode. MaxIters
107
-
108
136
fus = tuple ($ (Tuple (resids)... ))
109
137
minfu, idx = __findmin (cache. caches[1 ]. internalnorm, fus)
110
138
stats = cache. caches[idx]. stats
111
- u = cache. caches[idx]. u
139
+ u = get_u (cache. caches[idx])
140
+ retcode = cache. caches[idx]. retcode
112
141
113
142
return SciMLBase. build_solution (cache. caches[idx]. prob, cache. alg, u, fus[idx];
114
143
retcode, stats, cache. caches[idx]. trace)
117
146
return Expr (:block , calls... )
118
147
end
119
148
149
+ @generated function __step! (
150
+ cache:: NonlinearSolvePolyAlgorithmCache{iip, N} , args... ; kwargs... ) where {iip, N}
151
+ calls = []
152
+ cache_syms = [gensym (" cache" ) for i in 1 : N]
153
+ for i in 1 : N
154
+ push! (calls,
155
+ quote
156
+ $ (cache_syms[i]) = cache. caches[$ (i)]
157
+ if $ (i) == cache. current
158
+ __step! ($ (cache_syms[i]), args... ; kwargs... )
159
+ $ (cache_syms[i]). nsteps += 1
160
+ if ! not_terminated ($ (cache_syms[i]))
161
+ if SciMLBase. successful_retcode ($ (cache_syms[i]). retcode)
162
+ cache. best = $ (i)
163
+ cache. force_stop = true
164
+ cache. retcode = $ (cache_syms[i]). retcode
165
+ else
166
+ cache. current = $ (i + 1 )
167
+ end
168
+ end
169
+ return
170
+ end
171
+ end )
172
+ end
173
+
174
+ push! (calls,
175
+ quote
176
+ if ! (1 ≤ cache. current ≤ length (cache. caches))
177
+ minfu, idx = __findmin (first (cache. caches). internalnorm, cache. caches)
178
+ cache. best = idx
179
+ cache. retcode = cache. caches[cache. best]. retcode
180
+ cache. force_stop = true
181
+ return
182
+ end
183
+ end )
184
+
185
+ return Expr (:block , calls... )
186
+ end
187
+
120
188
for (probType, pType) in ((:NonlinearProblem , :NLS ), (:NonlinearLeastSquaresProblem , :NLLS ))
121
189
algType = NonlinearSolvePolyAlgorithm{pType}
122
190
@eval begin
0 commit comments