Skip to content

Commit b2d00f6

Browse files
committed
partial: add free final time and bounds-handling
1 parent a686605 commit b2d00f6

File tree

4 files changed

+99
-29
lines changed

4 files changed

+99
-29
lines changed

ext/MTKJuMPControlExt.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,11 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
5555
guesses = Dict(), kwargs...)
5656
MTK.warn_overdetermined(sys, u0map)
5757
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
58+
@show _u0map
5859
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
5960
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
61+
62+
(f_i, f_o) = generate_control_function(sys)
6063
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, u0)
6164

6265
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
@@ -86,13 +89,14 @@ function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
8689
end
8790

8891
function init_model(sys, tsteps, u0map, u0)
89-
ctrls = controls(sys)
92+
ctrls = MTK.unbound_inputs(sys)
9093
states = unknowns(sys)
9194
model = InfiniteModel()
9295
@infinite_parameter(model, t in [tsteps[1], tsteps[end]], num_supports=length(tsteps))
9396
@variable(model, U[i = 1:length(states)], Infinite(t))
9497
@variable(model, V[1:length(ctrls)], Infinite(t))
9598

99+
set_bounds!(model, sys)
96100
add_jump_cost_function!(model, sys)
97101
add_user_constraints!(model, sys)
98102

@@ -103,6 +107,22 @@ function init_model(sys, tsteps, u0map, u0)
103107
return model
104108
end
105109

110+
function set_bounds!(model, sys)
111+
U = model[:U]
112+
for (i, u) in enumerate(unknowns(sys))
113+
lo, hi = MTK.getbounds(u)
114+
set_lower_bound(U[i], lo)
115+
set_upper_bound(U[i], hi)
116+
end
117+
118+
V = model[:V]
119+
for (i, v) in enumerate(MTK.unbound_inputs(sys))
120+
lo, hi = MTK.getbounds(v)
121+
set_lower_bound(V[i], lo)
122+
set_upper_bound(V[i], hi)
123+
end
124+
end
125+
106126
function add_jump_cost_function!(model::InfiniteModel, sys)
107127
jcosts = MTK.get_costs(sys)
108128
consolidate = MTK.get_consolidate(sys)
@@ -113,7 +133,7 @@ function add_jump_cost_function!(model::InfiniteModel, sys)
113133
iv = MTK.get_iv(sys)
114134

115135
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
116-
cidxmap = Dict([v => i for (i, v) in enumerate(controls(sys))])
136+
cidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
117137

118138
for st in unknowns(sys)
119139
x = operation(st)
@@ -123,7 +143,7 @@ function add_jump_cost_function!(model::InfiniteModel, sys)
123143
jcosts = map(c -> Symbolics.substitute(c, Dict(x(t) => subval)), jcosts)
124144
end
125145

126-
for ct in controls(sys)
146+
for ct in MTK.unbound_inputs(sys)
127147
p = operation(ct)
128148
t = only(arguments(ct))
129149
idx = cidxmap[p(iv)]
@@ -141,7 +161,7 @@ function add_user_constraints!(model::InfiniteModel, sys)
141161

142162
iv = MTK.get_iv(sys)
143163
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
144-
cidxmap = Dict([v => i for (i, v) in enumerate(controls(sys))])
164+
cidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
145165

146166
for st in unknowns(conssys)
147167
x = operation(st)
@@ -151,7 +171,7 @@ function add_user_constraints!(model::InfiniteModel, sys)
151171
jconstraints = map(c -> Symbolics.substitute(c, Dict(x(t) => subval)), jconstraints)
152172
end
153173

154-
for ct in controls(sys)
174+
for ct in MTK.unbound_inputs(sys)
155175
p = operation(ct)
156176
t = only(arguments(ct))
157177
idx = cidxmap[p(iv)]
@@ -185,9 +205,8 @@ function add_infopt_solve_constraints!(model::InfiniteModel, sys, pmap)
185205
V = model[:V]
186206

187207
stmap = Dict([v => U[i] for (i, v) in enumerate(unknowns(sys))])
188-
ctrlmap = Dict([v => V[i] for (i, v) in enumerate(controls(sys))])
208+
ctrlmap = Dict([v => V[i] for (i, v) in enumerate(MTK.unbound_inputs(sys))])
189209
submap = merge(stmap, ctrlmap, Dict(pmap))
190-
@show submap
191210

192211
# Differential equations
193212
diff_eqs = diff_equations(sys)

src/systems/optimal_control_interface.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@ function warn_overdetermined(sys, u0map)
1919
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
2020
end
2121
end
22+
23+
"""
24+
IntegralNorm. When applied to an expression.
25+
"""
26+
struct IntegralNorm end
27+

src/systems/problem_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ function maybe_build_initialization_problem(
640640
t = zero(floatT)
641641
end
642642

643+
@show u0map
643644
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
644645
sys, t, u0map, pmap; guesses, kwargs...)
645646
if state_values(initializeprob) !== nothing

test/extensions/jump_control.jl

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
using ModelingToolkit
2-
using JuMP, InfiniteOpt
2+
import JuMP, InfiniteOpt
33
using DiffEqDevTools, DiffEqBase
44
using SimpleDiffEq
55
using OrdinaryDiffEqSDIRK
66
using Ipopt
77
using BenchmarkTools
8+
using CairoMakie
89
const M = ModelingToolkit
910

1011
@testset "ODE Solution, no cost" begin
1112
# Test solving without anything attached.
1213
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
13-
M.@variables x(..) y(..)
14+
@variables x(..) y(..)
1415
t = M.t_nounits
1516
D = M.D_nounits
1617

1718
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
1819
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
1920

21+
@mtkbuild sys = ODESystem(eqs, t)
2022
tspan = (0.0, 1.0)
2123
u0map = [x(t) => 4.0, y(t) => 2.0]
2224
parammap ==> 1.5, β => 1.0, γ => 3.0, δ => 1.0]
23-
@mtkbuild sys = ODESystem(eqs, t)
2425

2526
# Test explicit method.
2627
jprob = JuMPControlProblem(sys, u0map, tspan, parammap, dt = 0.01)
@@ -58,27 +59,70 @@ const M = ModelingToolkit
5859
sol = isol.sol
5960
@test sol(0.6)[1] 3.5
6061
@test sol(0.3)[1] 7.0
62+
63+
# Test whole-interval constraints
64+
constr = [x(t) > 3, y(t) > 4]
65+
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
66+
iprob = InfiniteOptControlProblem(
67+
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
68+
isol = @btime solve(
69+
$iprob, Ipopt.Optimizer, derivative_method = OrthogonalCollocation(3), silent = true) # 48.564 ms, 9.58 MiB
70+
sol = isol.sol
71+
@test all(u -> u .> [3, 4], sol.u)
72+
73+
jprob = JuMPControlProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
74+
jsol = @btime solve($jprob, Ipopt.Optimizer, :RadauIA3, silent = true) # 12.190 s, 9.68 GiB
75+
sol = jsol.sol
76+
@test all(u -> u .> [3, 4], sol.u)
6177
end
6278

63-
#@testset "Optimal control: bees" begin
64-
# # Example from Lawrence Evans' notes
65-
# M.@variables w(..) q(..)
66-
# M.@parameters α(t) [bounds = [0, 1]] b c μ s ν
67-
#
68-
# tspan = (0, 4)
69-
# eqs = [D(w(t)) ~ -μ*w(t) + b*s*α*w(t),
70-
# D(q(t)) ~ -ν*q(t) + c*(1 - α)*s*w(t)]
71-
# costs = [-q(tspan[2])]
72-
#
73-
# @mtkbuild beesys = ODESystem(eqs, t; costs)
74-
# u0map = [w(0) => 40, q(0) => 2]
75-
# pmap = [b => 1, c => 1, μ => 1, s => 1, ν => 1]
76-
#
77-
# jprob = JuMPControlProblem(beesys, u0map, tspan, pmap)
78-
# jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
79-
# control_sol = jsol.control_sol
80-
# # Bang-bang control
81-
#end
79+
@testset "Linear systems" begin
80+
function is_bangbang(input_sol, lbounds, ubounds)
81+
bangbang = true
82+
for v in 1:length(input_sol.u[1])
83+
all(i -> i[v] bounds[v] || i[v] bounds[u], input_sol.u) || (bangbang = false)
84+
end
85+
bangbang
86+
end
87+
88+
# Double integrator
89+
@variables x(..) [bounds = (0., 0.25)] v(..)
90+
@variables u(t) [bounds = (-1., 1.), input = true]
91+
constr = [v(1.0) ~ 0.0]
92+
cost = [-x(1.0)] # Optimize the final distance.
93+
@named block = ODESystem([D(x(t)) ~ v(t), D(v(t)) ~ u], t)
94+
block, input_idxs = structural_simplify(block, ([u],[]))
95+
96+
u0map = [x(t) => 0., v(t) => 0.]
97+
tspan = (0., 1.)
98+
parammap = [u => 0.]
99+
jprob = JuMPControlProblem(block, u0map, tspan, parammap; dt = 0.01)
100+
jsol = solve(jprob, Ipopt.Optimizer, :Verner8)
101+
# Linear systems have bang-bang controls
102+
@test is_bangbang(jsol.input_sol, [-1.], [1.])
103+
# Test reached final position.
104+
@test jsol.sol.u[end][1] 0.25
105+
106+
# Cart-pole system
107+
108+
# Bee example (from Lawrence Evans' notes)
109+
M.@variables w(..) q(..)
110+
M.@parameters α(t) [bounds = [0, 1]] b c μ s ν
111+
112+
tspan = (0, 4)
113+
eqs = [D(w(t)) ~ -μ*w(t) + b*s*α*w(t),
114+
D(q(t)) ~ -ν*q(t) + c*(1 - α)*s*w(t)]
115+
costs = [-q(tspan[2])]
116+
117+
@mtkbuild beesys = ODESystem(eqs, t; costs)
118+
u0map = [w(0) => 40, q(0) => 2]
119+
pmap = [b => 1, c => 1, μ => 1, s => 1, ν => 1]
120+
121+
jprob = JuMPControlProblem(beesys, u0map, tspan, pmap)
122+
jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
123+
control_sol = jsol.control_sol
124+
# Bang-bang control
125+
end
82126
#
83127
#@testset "Constrained optimal control problems" begin
84128
#end

0 commit comments

Comments
 (0)