Skip to content

Commit 44b73d1

Browse files
standardized solution type
1 parent 7c4a49b commit 44b73d1

File tree

3 files changed

+11
-51
lines changed

3 files changed

+11
-51
lines changed

src/GalacticOptim.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using ArrayInterface, Base.Iterators
1111

1212
using ForwardDiff: DEFAULT_CHUNK_THRESHOLD
1313
import DiffEqBase: OptimizationProblem, OptimizationFunction, AbstractADType
14-
import SciMLBase: AbstractNoTimeSolution, build_solution, AbstractNonlinearProblem
1514

1615
import ModelingToolkit
1716
import ModelingToolkit: AutoModelingToolkit

src/solve.jl

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,3 @@
1-
abstract type AbstractOptimizationSolution{T, N} <: AbstractNoTimeSolution{T, N} end
2-
3-
struct OptimizationSolution{T, N, uType, P, A, Tf} <: AbstractOptimizationSolution{T, N}
4-
u::uType # minimizer
5-
prob::P # optimization problem
6-
alg::A # algorithm
7-
minimum::Tf
8-
initial_x::Array{Float64,1}
9-
retcode::Symbol
10-
original::String # original output of the optimizer
11-
end
12-
13-
function build_solution(prob::AbstractNonlinearProblem,
14-
alg, u, minimum;
15-
initial_x = prob.u0,
16-
retcode = :Default,
17-
original = nothing,
18-
kwargs...)
19-
20-
T = eltype(eltype(u))
21-
N = ndims(u)
22-
23-
OptimizationSolution{T, N, typeof(u), typeof(prob), typeof(alg),
24-
typeof(minimum)}
25-
(u, prob, alg, minimum, initial_x,
26-
retcode, original)
27-
end
28-
29-
function Base.show(io::IO, A::AbstractNoTimeSolution)
30-
31-
@printf io "\n * Status: %s\n\n" A.retcode === :Success ? "success" : "failure"
32-
@printf io " * Candidate solution\n"
33-
@printf io " Final objective value: %e\n" A.minimum
34-
@printf io "\n"
35-
@printf io " * Found with\n"
36-
@printf io " Algorithm: %s\n" A.alg
37-
return
38-
end
39-
401
struct NullData end
412
const DEFAULT_DATA = Iterators.cycle((NullData(),))
423
Base.iterate(::NullData, i=1) = nothing
@@ -46,14 +7,6 @@ get_maxiters(data) = Iterators.IteratorSize(typeof(DEFAULT_DATA)) isa Iterators.
467
Iterators.IteratorSize(typeof(DEFAULT_DATA)) isa Iterators.SizeUnknown ?
478
typemax(Int) : length(data)
489

49-
struct EnsembleOptimizationProblem
50-
prob::Array{T, 1} where T<:OptimizationProblem
51-
end
52-
53-
function DiffEqBase.solve(prob::Union{OptimizationProblem,EnsembleOptimizationProblem}, opt, args...;kwargs...)
54-
__solve(prob, opt, args...; kwargs...)
55-
end
56-
5710
#=
5811
function update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
5912
x .-= x̄
@@ -167,6 +120,7 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
167120

168121
_time = time()
169122

123+
build_solution(prob, opt, θ, x[1])
170124
# here should be build_solution to create the output message
171125
end
172126

@@ -225,8 +179,15 @@ function __solve(prob::OptimizationProblem, opt::Optim.AbstractOptimizer,
225179
optim_f = TwiceDifferentiable(_loss, (G, θ) -> f.grad(G, θ, cur...), fg!, (H,θ) -> f.hess(H,θ,cur...), prob.u0)
226180
end
227181

228-
Optim.optimize(optim_f, prob.u0, opt, !(isnothing(maxiters)) ? Optim.Options(;extended_trace = true, callback = _cb, iterations = maxiters, kwargs...)
229-
: Optim.Options(;extended_trace = true, callback = _cb, kwargs...))
182+
original = Optim.optimize(optim_f, prob.u0, opt,
183+
!(isnothing(maxiters)) ?
184+
Optim.Options(;extended_trace = true,
185+
callback = _cb,
186+
iterations = maxiters,
187+
kwargs...) :
188+
Optim.Options(;extended_trace = true,
189+
callback = _cb, kwargs...))
190+
build_solution(prob, opt, original.minimizer, original.minimum; original=original)
230191
end
231192

232193
function __solve(prob::OptimizationProblem, opt::Union{Optim.Fminbox,Optim.SAMIN},

test/diffeqfluxtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ callback = function (p, l, pred)
4040

4141
# using `remake` to re-create our `prob` with current parameters `p`
4242
remade_solution = solve(remake(prob_ode, p = p), Tsit5(), saveat = tsteps)
43-
43+
4444
# Tell sciml_train to not halt the optimization. If return true, then
4545
# optimization stops.
4646
return false

0 commit comments

Comments
 (0)