Skip to content

feat: create JuMPControlProblem for optimal control #3549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 47 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
f8f3ce9
init: JuMPProblem
vyudu Apr 1, 2025
89d2c0e
up
vyudu Apr 3, 2025
477af8b
add solver getter
vyudu Apr 7, 2025
fecc1b6
add test/project
vyudu Apr 7, 2025
35f5f0c
Implement solver tableau
vyudu Apr 7, 2025
7ac9ad5
up
vyudu Apr 7, 2025
89e216a
feat: solver tableau debugging
vyudu Apr 8, 2025
ab233ef
fix: use dt in the constraints
vyudu Apr 8, 2025
7d67731
test: add to runtests
vyudu Apr 8, 2025
3f0721e
remove cassadi file
vyudu Apr 8, 2025
f4be49b
up?
vyudu Apr 8, 2025
484ffda
fix: consolidate method
vyudu Apr 8, 2025
1cd9f9d
feat: InfiniteOptControlProblem
vyudu Apr 9, 2025
2641aa9
add InfiniteOPt dep
vyudu Apr 10, 2025
31724e7
feat: add InfiniteOptControlProblem
vyudu Apr 10, 2025
e679939
fix merge
vyudu Apr 18, 2025
c005848
refactor: add optimal control interface file
vyudu Apr 10, 2025
a4dc022
add set_silent option
vyudu Apr 10, 2025
bec2739
format
vyudu Apr 10, 2025
81d32e1
partial: add free final time and bounds-handling
vyudu Apr 16, 2025
3001542
implement ControlFunction
vyudu Apr 16, 2025
0a1afce
feat: working linear control problems
vyudu Apr 17, 2025
d470df6
feat: free final time problems
vyudu Apr 18, 2025
48cb00a
format
vyudu Apr 18, 2025
b127c73
test: add trasncription tests
vyudu Apr 22, 2025
740accf
init new project for optimal control tests
vyudu Apr 22, 2025
bdc782c
test: more test fixes
vyudu Apr 23, 2025
55c3933
more test fixes
vyudu Apr 25, 2025
7e41540
clean up comments and tests
vyudu Apr 25, 2025
feb9da4
format
vyudu Apr 25, 2025
f388613
rename Control -> DynamicOpt
vyudu Apr 25, 2025
9dcc6b9
test fixes
vyudu Apr 26, 2025
f4d1760
format
vyudu Apr 26, 2025
3c71a87
add JuMP to extensions Project.toml
vyudu Apr 28, 2025
1490586
fix tests
vyudu Apr 28, 2025
4690d5b
move jump control to downsteram
vyudu Apr 28, 2025
eb51290
fix more tests
vyudu Apr 28, 2025
ed2caaf
rtest: emove from extensions
vyudu Apr 28, 2025
15f3e8a
fix: make free final time problems with constraints work
vyudu Apr 29, 2025
ade94d9
format: format
vyudu Apr 29, 2025
16d2986
fix: fix rocket launch test
vyudu Apr 30, 2025
e5a2707
feat: add default ode solver'
vyudu May 3, 2025
5b13745
refactor: merge InfiniteOptExts
vyudu May 3, 2025
b32af3d
test default solver
vyudu May 3, 2025
5813992
remove python dynamicopt problems
vyudu May 3, 2025
dc0dbfa
use Module instead of Main
vyudu May 3, 2025
8773654
remove \int
vyudu May 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
431 changes: 427 additions & 4 deletions ext/MTKInfiniteOptExt.jl

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,4 +348,8 @@ export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
open_loop
function FMIComponent end

include("systems/optimal_control_interface.jl")
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem
export DynamicOptSolution

end # module
14 changes: 7 additions & 7 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
inputs = [inputs; disturbance_inputs]
end

sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
if !iscomplete(sys)
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
end

dvs = unknowns(sys)
ps = parameters(sys; initial_parameters = true)
Expand Down Expand Up @@ -248,11 +250,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
args = (ddvs, args...)
end
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
p_end = length(p) + 2 + implicit_dae)
p_end = length(p) + 2 + implicit_dae, kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
f = GeneratedFunctionWrapper{(
3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
f = f, f
f = GeneratedFunctionWrapper{(3, length(args) - length(p) + 1, is_split(sys))}(f...)
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
(; f, dvs, ps, io_sys = sys)
end
Expand Down Expand Up @@ -430,7 +430,7 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kw
augmented_sys = ODESystem(eqs, t, systems = [dsys], name = gensym(:outer))
augmented_sys = extend(augmented_sys, sys)

(f_oop, f_ip), dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
f, dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
[d]; kwargs...)
(f_oop, f_ip), augmented_sys, dvs, p, io_sys
f, augmented_sys, dvs, p, io_sys
end
2 changes: 1 addition & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function calculate_control_jacobian(sys::AbstractODESystem;
end

rhs = [eq.rhs for eq in full_equations(sys)]
ctrls = controls(sys)
ctrls = unbound_inputs(sys)

if sparse
jac = sparsejacobian(rhs, ctrls, simplify = simplify)
Expand Down
12 changes: 9 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;

if length(costs) > 1 && isnothing(consolidate)
error("Must specify a consolidation function for the costs vector.")
elseif length(costs) == 1 && isnothing(consolidate)
consolidate = u -> u[1]
end

assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
Expand Down Expand Up @@ -763,6 +765,7 @@ function process_constraint_system(
constraintps = OrderedSet()
for cons in constraints
collect_vars!(constraintsts, constraintps, cons, iv)
union!(constraintsts, collect_applied_operators(cons, Differential))
end

# Validate the states.
Expand Down Expand Up @@ -800,11 +803,14 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)

for var in auxvars
if !iscall(var)
occursin(iv, var) && (var ∈ sts ||
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
var ∈ sts ||
throw(ArgumentError("Time-independent variable $var is not an unknown of the system."))
elseif length(arguments(var)) > 1
throw(ArgumentError("Too many arguments for variable $var."))
elseif length(arguments(var)) == 1
if iscall(var) && operation(var) isa Differential
var = only(arguments(var))
end
arg = only(arguments(var))
operation(var)(iv) ∈ sts ||
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
Expand All @@ -813,7 +819,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
arg isa AbstractFloat ||
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))

isparameter(arg) && push!(auxps, arg)
(isparameter(arg) && !isequal(arg, iv)) && push!(auxps, arg)
else
var ∈ sts &&
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
Expand Down
162 changes: 162 additions & 0 deletions src/systems/optimal_control_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <:
SciMLBase.AbstractODEProblem{uType, tType, isinplace} end

struct DynamicOptSolution
model::Any
sol::ODESolution
input_sol::Union{Nothing, ODESolution}
end

function Base.show(io::IO, sol::DynamicOptSolution)
println("retcode: ", sol.sol.retcode, "\n")

println("Optimal control solution for following model:\n")
show(sol.model)

print("\n\nPlease query the model using sol.model, the solution trajectory for the system using sol.sol, or the solution trajectory for the controllers using sol.input_sol.")
end

function JuMPDynamicOptProblem end
function InfiniteOptDynamicOptProblem end

function warn_overdetermined(sys, u0map)
constraintsys = get_constraintsystem(sys)
if !isnothing(constraintsys)
(length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) &&
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
end
end

"""
Generate the control function f(x, u, p, t) from the ODESystem.
Input variables are automatically inferred but can be manually specified.
"""
function SciMLBase.ODEInputFunction{iip, specialize}(sys::ODESystem,
dvs = unknowns(sys),
ps = parameters(sys), u0 = nothing,
inputs = unbound_inputs(sys),
disturbance_inputs = disturbances(sys);
version = nothing, tgrad = false,
jac = false, controljac = false,
p = nothing, t = nothing,
eval_expression = false,
sparse = false, simplify = false,
eval_module = @__MODULE__,
steady_state = false,
checkbounds = false,
sparsity = false,
analytic = nothing,
split_idxs = nothing,
initialization_data = nothing,
cse = true,
kwargs...) where {iip, specialize}
(f), _, _ = generate_control_function(
sys, inputs, disturbance_inputs; eval_module, cse, kwargs...)

if tgrad
tgrad_gen = generate_tgrad(sys, dvs, ps;
simplify = simplify,
expression = Val{true},
expression_module = eval_module, cse,
checkbounds = checkbounds, kwargs...)
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
else
_tgrad = nothing
end

if jac
jac_gen = generate_jacobian(sys, dvs, ps;
simplify = simplify, sparse = sparse,
expression = Val{true},
expression_module = eval_module, cse,
checkbounds = checkbounds, kwargs...)
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)

_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
else
_jac = nothing
end

if controljac
cjac_gen = generate_control_jacobian(sys, dvs, ps;
simplify = simplify, sparse = sparse,
expression = Val{true},
expression_module = eval_module, cse,
checkbounds = checkbounds, kwargs...)
cjac_oop, cjac_iip = eval_or_rgf.(cjac_gen; eval_expression, eval_module)

_cjac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(cjac_oop, cjac_iip)
else
_cjac = nothing
end

M = calculate_massmatrix(sys)
_M = if sparse && !(u0 === nothing || M === I)
SparseArrays.sparse(M)
elseif u0 === nothing || M === I
M
else
ArrayInterface.restructure(u0 .* u0', M)
end

observedfun = ObservedFunctionCache(
sys; steady_state, eval_expression, eval_module, checkbounds, cse)

if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
W_prototype = similar(W_sparsity(sys), uElType)
controljac_prototype = similar(calculate_control_jacobian(sys), uElType)
else
W_prototype = nothing
controljac_prototype = nothing
end

ODEInputFunction{iip, specialize}(f;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
controljac = _cjac === nothing ? nothing : _cjac,
tgrad = _tgrad === nothing ? nothing : _tgrad,
mass_matrix = _M,
jac_prototype = W_prototype,
controljac_prototype = controljac_prototype,
observed = observedfun,
sparsity = sparsity ? W_sparsity(sys) : nothing,
analytic = analytic,
initialization_data)
end

function SciMLBase.ODEInputFunction(sys::AbstractODESystem, args...; kwargs...)
ODEInputFunction{true}(sys, args...; kwargs...)
end

function SciMLBase.ODEInputFunction{true}(sys::AbstractODESystem, args...;
kwargs...)
ODEInputFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.ODEInputFunction{false}(sys::AbstractODESystem, args...;
kwargs...)
ODEInputFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

# returns the JuMP timespan, the number of steps, and whether it is a free time problem.
function process_tspan(tspan, dt, steps)
is_free_time = false
if isnothing(dt) && isnothing(steps)
error("Must provide either the dt or the number of intervals to the collocation solvers (JuMP, InfiniteOpt, CasADi).")
elseif symbolic_type(tspan[1]) === ScalarSymbolic() ||
symbolic_type(tspan[2]) === ScalarSymbolic()
isnothing(steps) &&
error("Free final time problems require specifying the number of steps using the keyword arg `steps`, rather than dt.")
isnothing(dt) ||
@warn "Specified dt for free final time problem. This will be ignored; dt will be determined by the number of timesteps."

return steps, true
else
isnothing(steps) ||
@warn "Specified number of steps for problem with concrete tspan. This will be ignored; number of steps will be determined by dt."

return length(tspan[1]:dt:tspan[2]), false
end
end
5 changes: 5 additions & 0 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ function hasbounds(x)
any(isfinite.(b[1]) .|| isfinite.(b[2]))
end

function setbounds(x::Num, bounds)
(lb, ub) = bounds
setmetadata(x, VariableBounds, (lb, ub))
end

## Disturbance =================================================================
struct VariableDisturbance end
Symbolics.option_to_metadata_type(::Val{:disturbance}) = VariableDisturbance
Expand Down
13 changes: 12 additions & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
[deps]
ControlSystemsMTK = "687d7614-c7e5-45fc-bfc3-9ee385575c88"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[compat]
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/analysis_points.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ModelingToolkit, OrdinaryDiffEq, LinearAlgebra, ControlSystemsBase
using ModelingToolkit, OrdinaryDiffEqRosenbrock, LinearAlgebra, ControlSystemsBase
using ModelingToolkitStandardLibrary.Mechanical.Rotational
using ModelingToolkitStandardLibrary.Blocks
using ModelingToolkit: connect, t_nounits as t, D_nounits as D
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/inversemodel.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ModelingToolkit
using ModelingToolkitStandardLibrary
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq
using OrdinaryDiffEqRosenbrock
using SymbolicIndexingInterface
using Test
using ControlSystemsMTK: tf, ss, get_named_sensitivity, get_named_comp_sensitivity
Expand Down
Loading
Loading