Skip to content

Commit 5f14d6f

Browse files
committed
feat: InfiniteOptControlProblem
1 parent 2616802 commit 5f14d6f

File tree

1 file changed

+90
-15
lines changed

1 file changed

+90
-15
lines changed

ext/MTKJuMPControlExt.jl

+90-15
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ using DiffEqDevTools, DiffEqBase, SciMLBase
55
using LinearAlgebra
66
const MTK = ModelingToolkit
77

8-
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <: SciMLBase.AbstractODEProblem{uType, tType, isinplace}
8+
abstract type AbstractOptimalControlProblem{uType, tType, isinplace} <: SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
9+
10+
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <: AbstractOptimalControlProblem{uType, tType, isinplace}
911
f::F
1012
u0::uType
1113
tspan::tType
@@ -18,6 +20,19 @@ struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <: SciMLBase.Abstrac
1820
end
1921
end
2022

23+
struct InfiniteOptControlProblem{uType, tType, isinplace, P, F, K} <: SciMLBase.AbstractODEProblem{uType, tType, isinplace}
24+
f::F
25+
u0::uType
26+
tspan::tType
27+
p::P
28+
model::InfiniteModel
29+
kwargs::K
30+
31+
function InfiniteOptControlProblem(f, u0, tspan, p, model; kwargs...)
32+
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
33+
end
34+
end
35+
2136
"""
2237
JuMPControlProblem(sys::ODESystem, u0, tspan, p; dt)
2338
@@ -34,24 +49,52 @@ The constraints are:
3449
- The solver constraints that encode the time-stepping used by the solver
3550
"""
3651
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), guesses = Dict(), kwargs...)
37-
ts = tspan[1]
38-
te = tspan[2]
39-
steps = ts:dt:te
40-
ctrls = controls(sys)
41-
states = unknowns(sys)
4252
constraintsys = MTK.get_constraintsystem(sys)
43-
4453
if !isnothing(constraintsys)
4554
(length(constraints(constraintsys)) + length(u0map) > length(states)) &&
46-
@warn "The JuMPControlProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
55+
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
56+
end
57+
58+
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
59+
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
60+
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
61+
model = init_model(sys, tspan[1]:dt:tspan[2], u0map)
62+
63+
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
64+
end
65+
66+
"""
67+
InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap; dt)
68+
69+
Convert an ODESystem representing an optimal control system into a InfiniteOpt model
70+
for solving using optimization. Must provide `dt` for determining the length
71+
of the interpolation arrays.
72+
73+
Related to `JuMPControlProblem`, but directly adds the differential equations
74+
of the system as derivative constraints, rather than using a solver tableau.
75+
"""
76+
function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for InfiniteOptControlProblem."), guesses = Dict(), kwargs...)
77+
constraintsys = MTK.get_constraintsystem(sys)
78+
if !isnothing(constraintsys)
79+
(length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) &&
80+
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
4781
end
4882

4983
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
5084
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
5185
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
5286

87+
model = init_model(sys, tspan[1]:dt:tspan[2], u0map)
88+
add_infopt_solve_constraints!(model, sys, pmap)
89+
InfiniteOptControlProblem(f, u0, tspan, p, model, kwargs...)
90+
end
91+
92+
function init_model(sys, tsteps, u0map)
93+
ctrls = controls(sys)
94+
states = unknowns(sys)
95+
5396
model = InfiniteModel()
54-
@infinite_parameter(model, t in [ts, te], num_supports = length(steps))
97+
@infinite_parameter(model, t in [tsteps[1], tsteps[end]], num_supports = length(tsteps))
5598
@variable(model, U[i = 1:length(states)], Infinite(t))
5699
@variable(model, V[1:length(ctrls)], Infinite(t))
57100

@@ -61,8 +104,6 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("
61104
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
62105
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) : [stidxmap[k] for (k, v) in u0map]
63106
add_initial_constraints!(model, u0, u0_idxs, tspan)
64-
65-
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
66107
end
67108

68109
function add_jump_cost_function!(model, sys)
@@ -140,7 +181,29 @@ end
140181

141182
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
142183

143-
function add_solve_constraints!(prob, tableau)
184+
function add_infopt_solve_constraints!(model, sys, pmap)
185+
iv = get_iv(sys)
186+
t = model[:t]
187+
U = model[:U]
188+
V = model[:V]
189+
190+
stmap = Dict([v => U[i] for (i, v) in enumerate(unknowns(sys))])
191+
ctrlmap = Dict([v => V[i] for (i, v) in enumerate(controls(sys))])
192+
submap = merge(stmap, ctrlmap, pmap)
193+
194+
@register_symbolic _D(x) = (x, t)
195+
# Differential equations
196+
diff_eqs = diff_equations(sys)
197+
diff_eqs = map(e -> Symbolics.substitute(e, submap, Differential(iv) => _D), diff_eqs)
198+
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs == diff_eqs[i].rhs)
199+
200+
# Algebraic equations
201+
alg_eqs = alg_equations(sys)
202+
alg_eqs = map(e -> Symbolics.substitute(e, submap), alg_eqs)
203+
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs == alg_eqs[i].rhs)
204+
end
205+
206+
function add_jump_solve_constraints!(prob, tableau)
144207
A = tableau.A
145208
α = tableau.α
146209
c = tableau.c
@@ -202,7 +265,6 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
202265
model = prob.model
203266
tableau_getter = Symbol(:construct, ode_solver)
204267
@eval tableau = $tableau_getter()
205-
ts = supports(model[:t])
206268

207269
# Unregister current solver constraints
208270
for con in all_constraints(model)
@@ -218,23 +280,36 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
218280
end
219281
end
220282
add_solve_constraints!(prob, tableau)
283+
_solve(prob, jump_solver, ode_solver)
284+
end
221285

286+
"""
287+
`derivative_method` kwarg refers to the method used by InfiniteOpt to compute derivatives. The list of possible options can be found at https://infiniteopt.github.io/InfiniteOpt.jl/stable/guide/derivative/. Defaults to FiniteDifference(Backward()).
288+
"""
289+
function DiffEqBase.solve(prob::InfiniteOptControlProblem, jump_solver; derivative_method = InfiniteOpt.FiniteDifference(Backward()))
290+
set_derivative_method(prob.model[:t], derivative_method)
291+
_solve(prob, jump_solver, derivative_method)
292+
end
293+
294+
function _solve(prob::AbstractOptimalControlProblem, jump_solver, solver)
295+
model = prob.model
222296
set_optimizer(model, jump_solver)
223297
optimize!(model)
224298

225299
tstatus = termination_status(model)
226300
pstatus = primal_status(model)
227301
!has_values(model) && error("Model not solvable; please report this to github.com/SciML/ModelingToolkit.jl.")
228302

303+
ts = supports(model[:t])
229304
U_vals = value.(model[:U])
230305
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:length(ts)]
231-
sol = DiffEqBase.build_solution(prob, ode_solver, ts, U_vals)
306+
sol = DiffEqBase.build_solution(prob, solver, ts, U_vals)
232307

233308
input_sol = nothing
234309
if !isempty(model[:V])
235310
V_vals = value.(model[:V])
236311
V_vals = [[V_vals[i][j] for i in 1:length(V_vals)] for j in 1:length(ts)]
237-
input_sol = DiffEqBase.build_solution(prob, ode_solver, ts, V_vals)
312+
input_sol = DiffEqBase.build_solution(prob, solver, ts, V_vals)
238313
end
239314

240315
if !(pstatus === FEASIBLE_POINT && (tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL || tstatus === ALMOST_LOCALLY_SOLVED))

0 commit comments

Comments
 (0)