Skip to content

Commit 1d32de6

Browse files
Merge pull request #3549 from vyudu/jump
feat: create JuMPControlProblem for optimal control
2 parents 7a9c17f + fe29a73 commit 1d32de6

19 files changed

+955
-52
lines changed

ext/MTKInfiniteOptExt.jl

Lines changed: 422 additions & 4 deletions
Large diffs are not rendered by default.

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,4 +348,8 @@ export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
348348
open_loop
349349
function FMIComponent end
350350

351+
include("systems/optimal_control_interface.jl")
352+
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem
353+
export DynamicOptSolution
354+
351355
end # module

src/inputoutput.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
208208
inputs = [inputs; disturbance_inputs]
209209
end
210210

211-
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
211+
if !iscomplete(sys)
212+
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
213+
end
212214

213215
dvs = unknowns(sys)
214216
ps = parameters(sys; initial_parameters = true)
@@ -248,11 +250,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
248250
args = (ddvs, args...)
249251
end
250252
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
251-
p_end = length(p) + 2 + implicit_dae)
253+
p_end = length(p) + 2 + implicit_dae, kwargs...)
252254
f = eval_or_rgf.(f; eval_expression, eval_module)
253-
f = GeneratedFunctionWrapper{(
254-
3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
255-
f = f, f
255+
f = GeneratedFunctionWrapper{(3, length(args) - length(p) + 1, is_split(sys))}(f...)
256256
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
257257
(; f, dvs, ps, io_sys = sys)
258258
end
@@ -430,7 +430,7 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kw
430430
augmented_sys = ODESystem(eqs, t, systems = [dsys], name = gensym(:outer))
431431
augmented_sys = extend(augmented_sys, sys)
432432

433-
(f_oop, f_ip), dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
433+
f, dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
434434
[d]; kwargs...)
435-
(f_oop, f_ip), augmented_sys, dvs, p, io_sys
435+
f, augmented_sys, dvs, p, io_sys
436436
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function calculate_control_jacobian(sys::AbstractODESystem;
101101
end
102102

103103
rhs = [eq.rhs for eq in full_equations(sys)]
104-
ctrls = controls(sys)
104+
ctrls = unbound_inputs(sys)
105105

106106
if sparse
107107
jac = sparsejacobian(rhs, ctrls, simplify = simplify)

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
354354

355355
if length(costs) > 1 && isnothing(consolidate)
356356
error("Must specify a consolidation function for the costs vector.")
357+
elseif length(costs) == 1 && isnothing(consolidate)
358+
consolidate = u -> u[1]
357359
end
358360

359361
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
@@ -763,6 +765,7 @@ function process_constraint_system(
763765
constraintps = OrderedSet()
764766
for cons in constraints
765767
collect_vars!(constraintsts, constraintps, cons, iv)
768+
union!(constraintsts, collect_applied_operators(cons, Differential))
766769
end
767770

768771
# Validate the states.
@@ -800,11 +803,14 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
800803

801804
for var in auxvars
802805
if !iscall(var)
803-
occursin(iv, var) && (var sts ||
804-
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
806+
var sts ||
807+
throw(ArgumentError("Time-independent variable $var is not an unknown of the system."))
805808
elseif length(arguments(var)) > 1
806809
throw(ArgumentError("Too many arguments for variable $var."))
807810
elseif length(arguments(var)) == 1
811+
if iscall(var) && operation(var) isa Differential
812+
var = only(arguments(var))
813+
end
808814
arg = only(arguments(var))
809815
operation(var)(iv) sts ||
810816
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
@@ -813,7 +819,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
813819
arg isa AbstractFloat ||
814820
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
815821

816-
isparameter(arg) && push!(auxps, arg)
822+
(isparameter(arg) && !isequal(arg, iv)) && push!(auxps, arg)
817823
else
818824
var sts &&
819825
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <:
2+
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end
3+
4+
struct DynamicOptSolution
5+
model::Any
6+
sol::ODESolution
7+
input_sol::Union{Nothing, ODESolution}
8+
end
9+
10+
function Base.show(io::IO, sol::DynamicOptSolution)
11+
println("retcode: ", sol.sol.retcode, "\n")
12+
13+
println("Optimal control solution for following model:\n")
14+
show(sol.model)
15+
16+
print("\n\nPlease query the model using sol.model, the solution trajectory for the system using sol.sol, or the solution trajectory for the controllers using sol.input_sol.")
17+
end
18+
19+
function JuMPDynamicOptProblem end
20+
function InfiniteOptDynamicOptProblem end
21+
22+
function warn_overdetermined(sys, u0map)
23+
constraintsys = get_constraintsystem(sys)
24+
if !isnothing(constraintsys)
25+
(length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) &&
26+
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
27+
end
28+
end
29+
30+
"""
31+
Generate the control function f(x, u, p, t) from the ODESystem.
32+
Input variables are automatically inferred but can be manually specified.
33+
"""
34+
function SciMLBase.ODEInputFunction{iip, specialize}(sys::ODESystem,
35+
dvs = unknowns(sys),
36+
ps = parameters(sys), u0 = nothing,
37+
inputs = unbound_inputs(sys),
38+
disturbance_inputs = disturbances(sys);
39+
version = nothing, tgrad = false,
40+
jac = false, controljac = false,
41+
p = nothing, t = nothing,
42+
eval_expression = false,
43+
sparse = false, simplify = false,
44+
eval_module = @__MODULE__,
45+
steady_state = false,
46+
checkbounds = false,
47+
sparsity = false,
48+
analytic = nothing,
49+
split_idxs = nothing,
50+
initialization_data = nothing,
51+
cse = true,
52+
kwargs...) where {iip, specialize}
53+
(f), _, _ = generate_control_function(
54+
sys, inputs, disturbance_inputs; eval_module, cse, kwargs...)
55+
56+
if tgrad
57+
tgrad_gen = generate_tgrad(sys, dvs, ps;
58+
simplify = simplify,
59+
expression = Val{true},
60+
expression_module = eval_module, cse,
61+
checkbounds = checkbounds, kwargs...)
62+
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
63+
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
64+
else
65+
_tgrad = nothing
66+
end
67+
68+
if jac
69+
jac_gen = generate_jacobian(sys, dvs, ps;
70+
simplify = simplify, sparse = sparse,
71+
expression = Val{true},
72+
expression_module = eval_module, cse,
73+
checkbounds = checkbounds, kwargs...)
74+
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
75+
76+
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
77+
else
78+
_jac = nothing
79+
end
80+
81+
if controljac
82+
cjac_gen = generate_control_jacobian(sys, dvs, ps;
83+
simplify = simplify, sparse = sparse,
84+
expression = Val{true},
85+
expression_module = eval_module, cse,
86+
checkbounds = checkbounds, kwargs...)
87+
cjac_oop, cjac_iip = eval_or_rgf.(cjac_gen; eval_expression, eval_module)
88+
89+
_cjac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(cjac_oop, cjac_iip)
90+
else
91+
_cjac = nothing
92+
end
93+
94+
M = calculate_massmatrix(sys)
95+
_M = if sparse && !(u0 === nothing || M === I)
96+
SparseArrays.sparse(M)
97+
elseif u0 === nothing || M === I
98+
M
99+
else
100+
ArrayInterface.restructure(u0 .* u0', M)
101+
end
102+
103+
observedfun = ObservedFunctionCache(
104+
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
105+
106+
if sparse
107+
uElType = u0 === nothing ? Float64 : eltype(u0)
108+
W_prototype = similar(W_sparsity(sys), uElType)
109+
controljac_prototype = similar(calculate_control_jacobian(sys), uElType)
110+
else
111+
W_prototype = nothing
112+
controljac_prototype = nothing
113+
end
114+
115+
ODEInputFunction{iip, specialize}(f;
116+
sys = sys,
117+
jac = _jac === nothing ? nothing : _jac,
118+
controljac = _cjac === nothing ? nothing : _cjac,
119+
tgrad = _tgrad === nothing ? nothing : _tgrad,
120+
mass_matrix = _M,
121+
jac_prototype = W_prototype,
122+
controljac_prototype = controljac_prototype,
123+
observed = observedfun,
124+
sparsity = sparsity ? W_sparsity(sys) : nothing,
125+
analytic = analytic,
126+
initialization_data)
127+
end
128+
129+
function SciMLBase.ODEInputFunction(sys::AbstractODESystem, args...; kwargs...)
130+
ODEInputFunction{true}(sys, args...; kwargs...)
131+
end
132+
133+
function SciMLBase.ODEInputFunction{true}(sys::AbstractODESystem, args...;
134+
kwargs...)
135+
ODEInputFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
136+
end
137+
138+
function SciMLBase.ODEInputFunction{false}(sys::AbstractODESystem, args...;
139+
kwargs...)
140+
ODEInputFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
141+
end
142+
143+
# returns the JuMP timespan, the number of steps, and whether it is a free time problem.
144+
function process_tspan(tspan, dt, steps)
145+
is_free_time = false
146+
if isnothing(dt) && isnothing(steps)
147+
error("Must provide either the dt or the number of intervals to the collocation solvers (JuMP, InfiniteOpt, CasADi).")
148+
elseif symbolic_type(tspan[1]) === ScalarSymbolic() ||
149+
symbolic_type(tspan[2]) === ScalarSymbolic()
150+
isnothing(steps) &&
151+
error("Free final time problems require specifying the number of steps using the keyword arg `steps`, rather than dt.")
152+
isnothing(dt) ||
153+
@warn "Specified dt for free final time problem. This will be ignored; dt will be determined by the number of timesteps."
154+
155+
return steps, true
156+
else
157+
isnothing(steps) ||
158+
@warn "Specified number of steps for problem with concrete tspan. This will be ignored; number of steps will be determined by dt."
159+
160+
return length(tspan[1]:dt:tspan[2]), false
161+
end
162+
end

src/variables.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ function hasbounds(x)
332332
any(isfinite.(b[1]) .|| isfinite.(b[2]))
333333
end
334334

335+
function setbounds(x::Num, bounds)
336+
(lb, ub) = bounds
337+
setmetadata(x, VariableBounds, (lb, ub))
338+
end
339+
335340
## Disturbance =================================================================
336341
struct VariableDisturbance end
337342
Symbolics.option_to_metadata_type(::Val{:disturbance}) = VariableDisturbance

test/downstream/Project.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
[deps]
22
ControlSystemsMTK = "687d7614-c7e5-45fc-bfc3-9ee385575c88"
3+
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
4+
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
5+
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
6+
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
7+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
38
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
49
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
510
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
6-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
11+
OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
12+
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
13+
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
14+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
15+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
16+
OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
17+
SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
718
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
819

920
[compat]

test/downstream/analysis_points.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, OrdinaryDiffEq, LinearAlgebra, ControlSystemsBase
1+
using ModelingToolkit, OrdinaryDiffEqRosenbrock, LinearAlgebra, ControlSystemsBase
22
using ModelingToolkitStandardLibrary.Mechanical.Rotational
33
using ModelingToolkitStandardLibrary.Blocks
44
using ModelingToolkit: connect, t_nounits as t, D_nounits as D

test/downstream/inversemodel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ModelingToolkit
22
using ModelingToolkitStandardLibrary
33
using ModelingToolkitStandardLibrary.Blocks
4-
using OrdinaryDiffEq
4+
using OrdinaryDiffEqRosenbrock
55
using SymbolicIndexingInterface
66
using Test
77
using ControlSystemsMTK: tf, ss, get_named_sensitivity, get_named_comp_sensitivity

0 commit comments

Comments
 (0)