Skip to content

Commit 9647cae

Browse files
committed
fix merge
1 parent a6f117e commit 9647cae

File tree

2 files changed

+81
-58
lines changed

2 files changed

+81
-58
lines changed

ext/MTKJuMPControlExt.jl

Lines changed: 72 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@ using DiffEqDevTools, DiffEqBase, SciMLBase
55
using LinearAlgebra
66
const MTK = ModelingToolkit
77

8-
abstract type AbstractOptimalControlProblem{uType, tType, isinplace} <: SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
9-
10-
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <: AbstractOptimalControlProblem{uType, tType, isinplace}
11-
f::F
12-
u0::uType
13-
tspan::tType
14-
p::P
15-
model::InfiniteModel
16-
kwargs::K
17-
18-
function JuMPControlProblem(f, u0, tspan, p, model; kwargs...)
19-
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
20-
end
8+
abstract type AbstractOptimalControlProblem{uType, tType, isinplace} <:
9+
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
10+
11+
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <:
12+
AbstractOptimalControlProblem{uType, tType, isinplace}
13+
f::F
14+
u0::uType
15+
tspan::tType
16+
p::P
17+
model::InfiniteModel
18+
kwargs::K
19+
20+
function JuMPControlProblem(f, u0, tspan, p, model; kwargs...)
21+
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f),
22+
typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
23+
end
2124
end
2225

23-
struct InfiniteOptControlProblem{uType, tType, isinplace, P, F, K} <: AbstractOptimalControlProblem{uType, tType, isinplace}
24-
f::F
25-
u0::uType
26-
tspan::tType
27-
p::P
28-
model::InfiniteModel
29-
kwargs::K
30-
31-
function InfiniteOptControlProblem(f, u0, tspan, p, model; kwargs...)
32-
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
33-
end
26+
struct InfiniteOptControlProblem{uType, tType, isinplace, P, F, K} <:
27+
AbstractOptimalControlProblem{uType, tType, isinplace}
28+
f::F
29+
u0::uType
30+
tspan::tType
31+
p::P
32+
model::InfiniteModel
33+
kwargs::K
34+
35+
function InfiniteOptControlProblem(f, u0, tspan, p, model; kwargs...)
36+
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f),
37+
typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
38+
end
3439
end
3540

3641
"""
@@ -48,7 +53,9 @@ The constraints are:
4853
- The set of user constraints passed to the ODESystem via `constraints`
4954
- The solver constraints that encode the time-stepping used by the solver
5055
"""
51-
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), guesses = Dict(), kwargs...)
56+
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
57+
dt = error("dt must be provided for JuMPControlProblem."),
58+
guesses = Dict(), kwargs...)
5259
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
5360
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
5461
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
@@ -67,7 +74,9 @@ of the interpolation arrays.
6774
Related to `JuMPControlProblem`, but directly adds the differential equations
6875
of the system as derivative constraints, rather than using a solver tableau.
6976
"""
70-
function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for InfiniteOptControlProblem."), guesses = Dict(), kwargs...)
77+
function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
78+
dt = error("dt must be provided for InfiniteOptControlProblem."),
79+
guesses = Dict(), kwargs...)
7180
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
7281
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
7382
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
@@ -87,15 +96,16 @@ function init_model(sys, tsteps, u0map, u0)
8796
ctrls = controls(sys)
8897
states = unknowns(sys)
8998
model = InfiniteModel()
90-
@infinite_parameter(model, t in [tsteps[1], tsteps[end]], num_supports = length(tsteps))
99+
@infinite_parameter(model, t in [tsteps[1], tsteps[end]], num_supports=length(tsteps))
91100
@variable(model, U[i = 1:length(states)], Infinite(t))
92101
@variable(model, V[1:length(ctrls)], Infinite(t))
93102

94103
add_jump_cost_function!(model, sys)
95104
add_user_constraints!(model, sys)
96105

97106
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
98-
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) : [stidxmap[k] for (k, v) in u0map]
107+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
108+
[stidxmap[k] for (k, v) in u0map]
99109
add_initial_constraints!(model, u0, u0_idxs, tsteps[1])
100110
return model
101111
end
@@ -119,15 +129,15 @@ function add_jump_cost_function!(model, sys)
119129
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
120130
jcosts = map(c -> Symbolics.substitute(c, Dict(x(t) => subval)), jcosts)
121131
end
122-
132+
123133
for ct in controls(sys)
124134
p = operation(ct)
125135
t = only(arguments(ct))
126136
idx = cidxmap[p(iv)]
127137
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
128138
jcosts = map(c -> Symbolics.substitute(c, Dict(p(t) => subval)), jcosts)
129139
end
130-
140+
131141
@objective(model, Min, consolidate(jcosts))
132142
end
133143

@@ -153,23 +163,24 @@ function add_user_constraints!(model, sys)
153163
t = only(arguments(ct))
154164
idx = cidxmap[p(iv)]
155165
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
156-
jconstraints = map(c -> Symbolics.substitute(jconstraints, Dict(p(t) => subval)), jconstriants)
166+
jconstraints = map(
167+
c -> Symbolics.substitute(jconstraints, Dict(p(t) => subval)), jconstriants)
157168
end
158169

159170
for (i, cons) in enumerate(jconstraints)
160171
if cons isa Equation
161-
@constraint(model, cons.lhs - cons.rhs == 0, base_name = "user[$i]")
162-
elseif cons.relational_op === Symbolics.geq
163-
@constraint(model, cons.lhs - cons.rhs 0, base_name = "user[$i]")
172+
@constraint(model, cons.lhs - cons.rhs==0, base_name="user[$i]")
173+
elseif cons.relational_op === Symbolics.geq
174+
@constraint(model, cons.lhs - cons.rhs0, base_name="user[$i]")
164175
else
165-
@constraint(model, cons.lhs - cons.rhs 0, base_name = "user[$i]")
176+
@constraint(model, cons.lhs - cons.rhs0, base_name="user[$i]")
166177
end
167178
end
168179
end
169180

170181
function add_initial_constraints!(model, u0, u0_idxs, ts)
171182
U = model[:U]
172-
@constraint(model, initial[i in u0_idxs], U[i](ts) == u0[i])
183+
@constraint(model, initial[i in u0_idxs], U[i](ts)==u0[i])
173184
end
174185

175186
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
@@ -193,12 +204,12 @@ function add_infopt_solve_constraints!(model, sys, pmap)
193204
diff_eqs = map(e -> Symbolics.substitute(e, submap), diff_eqs)
194205
diff_eqs = map(e -> Symbolics.substitute(e, diffsubmap), diff_eqs)
195206
end
196-
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs == diff_eqs[i].rhs)
207+
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs==diff_eqs[i].rhs)
197208

198209
# Algebraic equations
199210
alg_eqs = alg_equations(sys)
200211
alg_eqs = map(e -> Symbolics.substitute(e, submap), alg_eqs)
201-
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs == alg_eqs[i].rhs)
212+
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs==alg_eqs[i].rhs)
202213
end
203214

204215
function add_jump_solve_constraints!(prob, tableau)
@@ -219,26 +230,29 @@ function add_jump_solve_constraints!(prob, tableau)
219230
K = Any[]
220231
for τ in tsteps
221232
for (i, h) in enumerate(c)
222-
ΔU = sum([A[i, j] * K[j] for j in 1:i-1], init = zeros(nᵤ))
223-
Uₙ = [U[i](τ) + ΔU[i]*dt for i in 1:nᵤ]
224-
Kₙ = f(Uₙ, p, τ + h*dt)
233+
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
234+
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
235+
Kₙ = f(Uₙ, p, τ + h * dt)
225236
push!(K, Kₙ)
226237
end
227-
ΔU = dt*sum([α[i] * K[i] for i in 1:length(α)])
228-
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU[n] == U[n](τ + dt), base_name = "solve_time_")
238+
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
239+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU[n]==U[n](τ + dt),
240+
base_name="solve_time_")
229241
empty!(K)
230242
end
231243
else
232-
@variable(model, K[1:length(α), 1:nᵤ], Infinite(t), start = tsteps[1])
244+
@variable(model, K[1:length(α), 1:nᵤ], Infinite(t), start=tsteps[1])
233245
for τ in tsteps
234246
ΔUs = A * K
235247
for (i, h) in enumerate(c)
236248
ΔU = ΔUs[i, :]
237-
Uₙ = [U[j] + ΔU[j]*dt for j in 1:nᵤ]
238-
@constraint(model, [j in 1:nᵤ], K[i, j] == f(Uₙ, p, τ + h*dt)[j], DomainRestrictions(t => τ), base_name = "solve_K()")
249+
Uₙ = [U[j] + ΔU[j] * dt for j in 1:nᵤ]
250+
@constraint(model, [j in 1:nᵤ], K[i, j]==f(Uₙ, p, τ + h * dt)[j],
251+
DomainRestrictions(t => τ), base_name="solve_K()")
239252
end
240-
ΔU = dt*sum([α[i] * K[i, :] for i in 1:length(α)])
241-
@constraint(model, [n = 1:nᵤ], U[n] + ΔU[n] == U[n](τ + dt), DomainRestrictions(t => τ), base_name = "solve_U()")
253+
ΔU = dt * sum([α[i] * K[i, :] for i in 1:length(α)])
254+
@constraint(model, [n = 1:nᵤ], U[n] + ΔU[n]==U[n](τ + dt),
255+
DomainRestrictions(t => τ), base_name="solve_U()")
242256
end
243257
end
244258
end
@@ -271,7 +285,7 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
271285
delete(model, con)
272286
end
273287
end
274-
unregister(model, :K)
288+
unregister(model, :K)
275289
for var in all_variables(model)
276290
if occursin("K", JuMP.name(var))
277291
delete(model, var)
@@ -284,7 +298,8 @@ end
284298
"""
285299
`derivative_method` kwarg refers to the method used by InfiniteOpt to compute derivatives. The list of possible options can be found at https://infiniteopt.github.io/InfiniteOpt.jl/stable/guide/derivative/. Defaults to FiniteDifference(Backward()).
286300
"""
287-
function DiffEqBase.solve(prob::InfiniteOptControlProblem, jump_solver; derivative_method = InfiniteOpt.FiniteDifference(Backward()))
301+
function DiffEqBase.solve(prob::InfiniteOptControlProblem, jump_solver;
302+
derivative_method = InfiniteOpt.FiniteDifference(Backward()))
288303
set_derivative_method(prob.model[:t], derivative_method)
289304
_solve(prob, jump_solver, derivative_method)
290305
end
@@ -296,7 +311,8 @@ function _solve(prob::AbstractOptimalControlProblem, jump_solver, solver)
296311

297312
tstatus = termination_status(model)
298313
pstatus = primal_status(model)
299-
!has_values(model) && error("Model not solvable; please report this to github.com/SciML/ModelingToolkit.jl.")
314+
!has_values(model) &&
315+
error("Model not solvable; please report this to github.com/SciML/ModelingToolkit.jl.")
300316

301317
ts = supports(model[:t])
302318
U_vals = value.(model[:U])
@@ -310,9 +326,12 @@ function _solve(prob::AbstractOptimalControlProblem, jump_solver, solver)
310326
input_sol = DiffEqBase.build_solution(prob, solver, ts, V_vals)
311327
end
312328

313-
if !(pstatus === FEASIBLE_POINT && (tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL || tstatus === ALMOST_LOCALLY_SOLVED))
329+
if !(pstatus === FEASIBLE_POINT &&
330+
(tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL ||
331+
tstatus === ALMOST_LOCALLY_SOLVED))
314332
sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure)
315-
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
333+
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
334+
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
316335
end
317336

318337
JuMPControlSolution(model, sol, input_sol)

test/extensions/jump_control.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@ using JuMP, InfiniteOpt
33
using DiffEqDevTools, DiffEqBase
44
using SimpleDiffEq
55
using OrdinaryDiffEqSDIRK
6-
using Ipopt
6+
using Ipopt
77
using BenchmarkTools
88
const M = ModelingToolkit
99

1010
@testset "ODE Solution, no cost" begin
1111
# Test solving without anything attached.
1212
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
1313
M.@variables x(..) y(..)
14-
t = M.t_nounits; D = M.D_nounits
14+
t = M.t_nounits
15+
D = M.D_nounits
1516

1617
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
1718
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
@@ -34,7 +35,8 @@ const M = ModelingToolkit
3435
osol2 = @btime solve($oprob, ImplicitEuler(), dt = 0.01, adaptive = false) # 129.375 μs, 61.91 KiB
3536
@test (jsol2.sol.u, osol2.u, rtol = 0.001)
3637
iprob = InfiniteOptControlProblem(sys, u0map, tspan, parammap, dt = 0.01)
37-
isol = @btime solve($iprob, Ipopt.Optimizer, derivative_method = FiniteDifference(Backward())) # 11.540 ms, 4.00 MiB
38+
isol = @btime solve(
39+
$iprob, Ipopt.Optimizer, derivative_method = FiniteDifference(Backward())) # 11.540 ms, 4.00 MiB
3840

3941
# With a constraint
4042
u0map = Pair[]
@@ -49,8 +51,10 @@ const M = ModelingToolkit
4951
@test sol(0.6)[1] 3.5
5052
@test sol(0.3)[1] 7.0
5153

52-
iprob = InfiniteOptControlProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
53-
isol = @btime solve($iprob, Ipopt.Optimizer, derivative_method = OrthogonalCollocation(3)) # 48.564 ms, 9.58 MiB
54+
iprob = InfiniteOptControlProblem(
55+
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
56+
isol = @btime solve(
57+
$iprob, Ipopt.Optimizer, derivative_method = OrthogonalCollocation(3)) # 48.564 ms, 9.58 MiB
5458
sol = isol.sol
5559
@test sol(0.6)[1] 3.5
5660
@test sol(0.3)[1] 7.0

0 commit comments

Comments
 (0)