Skip to content

Commit dc10aa3

Browse files
committed
refactor: add optimal control interface file
1 parent 9647cae commit dc10aa3

File tree

3 files changed

+32
-27
lines changed

3 files changed

+32
-27
lines changed

ext/MTKJuMPControlExt.jl

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
module MTKJuMPControlExt
22
using ModelingToolkit
33
using JuMP, InfiniteOpt
4-
using DiffEqDevTools, DiffEqBase, SciMLBase
4+
using DiffEqDevTools, DiffEqBase
55
using LinearAlgebra
66
const MTK = ModelingToolkit
77

8-
abstract type AbstractOptimalControlProblem{uType, tType, isinplace} <:
9-
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
10-
118
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <:
129
AbstractOptimalControlProblem{uType, tType, isinplace}
1310
f::F
@@ -56,6 +53,7 @@ The constraints are:
5653
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
5754
dt = error("dt must be provided for JuMPControlProblem."),
5855
guesses = Dict(), kwargs...)
56+
MTK.warn_overdetermined(sys, u0map)
5957
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
6058
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
6159
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
@@ -77,6 +75,7 @@ of the system as derivative constraints, rather than using a solver tableau.
7775
function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
7876
dt = error("dt must be provided for InfiniteOptControlProblem."),
7977
guesses = Dict(), kwargs...)
78+
MTK.warn_overdetermined(sys, u0map)
8079
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
8180
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
8281
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
@@ -87,12 +86,6 @@ function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
8786
end
8887

8988
function init_model(sys, tsteps, u0map, u0)
90-
constraintsys = MTK.get_constraintsystem(sys)
91-
if !isnothing(constraintsys)
92-
(length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) &&
93-
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
94-
end
95-
9689
ctrls = controls(sys)
9790
states = unknowns(sys)
9891
model = InfiniteModel()
@@ -110,7 +103,7 @@ function init_model(sys, tsteps, u0map, u0)
110103
return model
111104
end
112105

113-
function add_jump_cost_function!(model, sys)
106+
function add_jump_cost_function!(model::InfiniteModel, sys)
114107
jcosts = MTK.get_costs(sys)
115108
consolidate = MTK.get_consolidate(sys)
116109
if isnothing(jcosts) || isempty(jcosts)
@@ -141,7 +134,7 @@ function add_jump_cost_function!(model, sys)
141134
@objective(model, Min, consolidate(jcosts))
142135
end
143136

144-
function add_user_constraints!(model, sys)
137+
function add_user_constraints!(model::InfiniteModel, sys)
145138
conssys = MTK.get_constraintsystem(sys)
146139
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
147140
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
@@ -178,14 +171,14 @@ function add_user_constraints!(model, sys)
178171
end
179172
end
180173

181-
function add_initial_constraints!(model, u0, u0_idxs, ts)
174+
function add_initial_constraints!(model::InfiniteModel, u0, u0_idxs, ts)
182175
U = model[:U]
183176
@constraint(model, initial[i in u0_idxs], U[i](ts)==u0[i])
184177
end
185178

186179
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
187180

188-
function add_infopt_solve_constraints!(model, sys, pmap)
181+
function add_infopt_solve_constraints!(model::InfiniteModel, sys, pmap)
189182
iv = MTK.get_iv(sys)
190183
t = model[:t]
191184
U = model[:U]
@@ -257,14 +250,6 @@ function add_jump_solve_constraints!(prob, tableau)
257250
end
258251
end
259252

260-
"""
261-
"""
262-
struct JuMPControlSolution
263-
model::InfiniteModel
264-
sol::ODESolution
265-
input_sol::Union{Nothing, ODESolution}
266-
end
267-
268253
"""
269254
Solve JuMPControlProblem. Arguments:
270255
- prob: a JumpControlProblem
@@ -334,7 +319,7 @@ function _solve(prob::AbstractOptimalControlProblem, jump_solver, solver)
334319
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
335320
end
336321

337-
JuMPControlSolution(model, sol, input_sol)
322+
OptimalControlSolution(model, sol, input_sol)
338323
end
339324

340325
end

src/ModelingToolkit.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,8 @@ export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
347347
open_loop
348348
function FMIComponent end
349349

350-
function JuMPControlProblem end
351-
export JuMPControlProblem
352-
function InfiniteOptControlProblem end
353-
export InfiniteOptControlProblem
350+
include("src/systems/optimal_control_interface.jl")
351+
export JuMPControlProblem, InfiniteOptControlProblem, PyomoControlProblem, CasADiControlProblem
352+
export OptimalControlSolution
354353

355354
end # module
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
abstract type AbstractOptimalControlProblem{uType, tType, isinplace} <:
2+
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
3+
4+
struct OptimalControlSolution
5+
model::Any
6+
sol::ODESolution
7+
input_sol::Union{Nothing, ODESolution}
8+
end
9+
10+
function JuMPControlProblem end
11+
function InfiniteOptControlProblem end
12+
function CasADiControlProblem end
13+
function PyomoControlProblem end
14+
15+
function warn_overdetermined(sys, u0map)
16+
constraintsys = get_constraintsystem(sys)
17+
if !isnothing(constraintsys)
18+
(length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) &&
19+
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
20+
end
21+
end

0 commit comments

Comments
 (0)