Skip to content

Commit 7f86e85

Browse files
committed
feat: working linear control problems
1 parent 6c53436 commit 7f86e85

File tree

4 files changed

+130
-84
lines changed

4 files changed

+130
-84
lines changed

ext/MTKJuMPControlExt.jl

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
5858
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
5959
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
6060

61-
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, u0)
61+
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, pmap, u0)
6262
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
6363
end
6464

@@ -80,22 +80,23 @@ function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
8080
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
8181
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
8282

83-
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, u0)
83+
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, pmap, u0)
8484
add_infopt_solve_constraints!(model, sys, pmap)
8585
InfiniteOptControlProblem(f, u0, tspan, p, model, kwargs...)
8686
end
8787

88-
function init_model(sys, tsteps, u0map, u0)
88+
function init_model(sys, tsteps, u0map, pmap, u0)
8989
ctrls = MTK.unbound_inputs(sys)
9090
states = unknowns(sys)
9191
model = InfiniteModel()
92+
9293
@infinite_parameter(model, t in [tsteps[1], tsteps[end]], num_supports=length(tsteps))
9394
@variable(model, U[i = 1:length(states)], Infinite(t))
9495
@variable(model, V[1:length(ctrls)], Infinite(t))
9596

9697
set_bounds!(model, sys)
97-
add_jump_cost_function!(model, sys)
98-
add_user_constraints!(model, sys)
98+
add_jump_cost_function!(model, sys, (tsteps[1], tsteps[2]), pmap)
99+
add_user_constraints!(model, sys, pmap)
99100

100101
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
101102
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
@@ -120,63 +121,35 @@ function set_bounds!(model, sys)
120121
end
121122
end
122123

123-
function add_jump_cost_function!(model::InfiniteModel, sys)
124+
function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap)
124125
jcosts = MTK.get_costs(sys)
125126
consolidate = MTK.get_consolidate(sys)
126127
if isnothing(jcosts) || isempty(jcosts)
127128
@objective(model, Min, 0)
128129
return
129130
end
130-
iv = MTK.get_iv(sys)
131-
132-
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
133-
cidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
134-
135-
for st in unknowns(sys)
136-
x = operation(st)
137-
t = only(arguments(st))
138-
idx = stidxmap[x(iv)]
139-
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
140-
jcosts = map(c -> Symbolics.substitute(c, Dict(x(t) => subval)), jcosts)
141-
end
131+
jcosts = substitute_jump_vars(model, sys, pmap, jcosts)
142132

143-
for ct in MTK.unbound_inputs(sys)
144-
p = operation(ct)
145-
t = only(arguments(ct))
146-
idx = cidxmap[p(iv)]
147-
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
148-
jcosts = map(c -> Symbolics.substitute(c, Dict(p(t) => subval)), jcosts)
133+
# Substitute integral
134+
iv = MTK.get_iv(sys)
135+
jcosts = map(c -> Symbolics.substitute(c, ∫ => Symbolics.Integral(iv in tspan)), jcosts)
136+
intmap = Dict()
137+
138+
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
139+
arg = only(arguments(MTK.value(int)))
140+
lower_bound, upper_bound = (int.domain.domain.left, int.domain.domain.right)
141+
intmap[int] = InfiniteOpt.(arg, iv; lower_bound, upper_bound)
149142
end
150-
143+
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
151144
@objective(model, Min, consolidate(jcosts))
152145
end
153146

154-
function add_user_constraints!(model::InfiniteModel, sys)
147+
function add_user_constraints!(model::InfiniteModel, sys, pmap)
155148
conssys = MTK.get_constraintsystem(sys)
156149
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
157150
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
158151

159-
iv = MTK.get_iv(sys)
160-
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
161-
cidxmap = Dict([v => i for (i, v) in enumerate(MTK.unbound_inputs(sys))])
162-
163-
for st in unknowns(conssys)
164-
x = operation(st)
165-
t = only(arguments(st))
166-
idx = stidxmap[x(iv)]
167-
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
168-
jconstraints = map(c -> Symbolics.substitute(c, Dict(x(t) => subval)), jconstraints)
169-
end
170-
171-
for ct in MTK.unbound_inputs(sys)
172-
p = operation(ct)
173-
t = only(arguments(ct))
174-
idx = cidxmap[p(iv)]
175-
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
176-
jconstraints = map(
177-
c -> Symbolics.substitute(jconstraints, Dict(p(t) => subval)), jconstriants)
178-
end
179-
152+
jconstraints = substitute_jump_vars(model, sys, pmap, jconstraints)
180153
for (i, cons) in enumerate(jconstraints)
181154
if cons isa Equation
182155
@constraint(model, cons.lhs - cons.rhs==0, base_name="user[$i]")
@@ -193,31 +166,41 @@ function add_initial_constraints!(model::InfiniteModel, u0, u0_idxs, ts)
193166
@constraint(model, initial[i in u0_idxs], U[i](ts)==u0[i])
194167
end
195168

196-
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
197-
198-
function add_infopt_solve_constraints!(model::InfiniteModel, sys, pmap)
169+
function substitute_jump_vars(model, sys, pmap, exprs)
199170
iv = MTK.get_iv(sys)
200-
t = model[:t]
171+
sts = unknowns(sys)
172+
cts = MTK.unbound_inputs(sys)
201173
U = model[:U]
202174
V = model[:V]
175+
# for variables like x(t)
176+
whole_interval_map = Dict([[v => U[i] for (i, v) in enumerate(sts)]; [v => V[i] for (i, v) in enumerate(cts)]])
177+
exprs = map(c -> Symbolics.substitute(c, whole_interval_map), exprs)
178+
179+
# for variables like x(1.0)
180+
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
181+
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
182+
fixed_t_map = Dict([[x_ops[i] => U[i] for i in 1:length(U)]; [c_ops[i] => V[i] for i in 1:length(V)]])
183+
exprs = map(c -> Symbolics.substitute(c, fixed_t_map), exprs)
184+
185+
exprs = map(c -> Symbolics.substitute(c, Dict(pmap)), exprs)
186+
exprs
187+
end
203188

204-
stmap = Dict([v => U[i] for (i, v) in enumerate(unknowns(sys))])
205-
ctrlmap = Dict([v => V[i] for (i, v) in enumerate(MTK.unbound_inputs(sys))])
206-
submap = merge(stmap, ctrlmap, Dict(pmap))
189+
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
207190

191+
function add_infopt_solve_constraints!(model::InfiniteModel, sys, pmap)
208192
# Differential equations
209-
diff_eqs = diff_equations(sys)
210-
D = Differential(iv)
193+
U = model[:U]
194+
t = model[:t]
195+
D = Differential(MTK.get_iv(sys))
211196
diffsubmap = Dict([D(U[i]) => (U[i], t) for i in 1:length(U)])
212-
for u in unknowns(sys)
213-
diff_eqs = map(e -> Symbolics.substitute(e, submap), diff_eqs)
214-
diff_eqs = map(e -> Symbolics.substitute(e, diffsubmap), diff_eqs)
215-
end
197+
198+
diff_eqs = substitute_jump_vars(model, sys, pmap, diff_equations(sys))
199+
diff_eqs = map(e -> Symbolics.substitute(e, diffsubmap), diff_eqs)
216200
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs==diff_eqs[i].rhs)
217201

218202
# Algebraic equations
219-
alg_eqs = alg_equations(sys)
220-
alg_eqs = map(e -> Symbolics.substitute(e, submap), alg_eqs)
203+
alg_eqs = substitute_jump_vars(model, sys, pmap, alg_equations(sys))
221204
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs==alg_eqs[i].rhs)
222205
end
223206

@@ -306,9 +289,10 @@ end
306289
`derivative_method` kwarg refers to the method used by InfiniteOpt to compute derivatives. The list of possible options can be found at https://infiniteopt.github.io/InfiniteOpt.jl/stable/guide/derivative/. Defaults to FiniteDifference(Backward()).
307290
"""
308291
function DiffEqBase.solve(prob::InfiniteOptControlProblem, jump_solver;
309-
derivative_method = InfiniteOpt.FiniteDifference(Backward()))
292+
derivative_method = InfiniteOpt.FiniteDifference(Backward()), silent = false)
293+
model = prob.model
310294
silent && set_silent(model)
311-
set_derivative_method(prob.model[:t], derivative_method)
295+
set_derivative_method(model[:t], derivative_method)
312296
_solve(prob, jump_solver, derivative_method)
313297
end
314298

src/systems/optimal_control_interface.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,15 @@ function SciMLBase.ControlFunction{false}(sys::AbstractODESystem, args...;
134134
end
135135

136136
"""
137-
IntegralNorm. When applied to an expression.
137+
IntegralNorm. When applied to an expression in a cost
138+
function, assumes that the integration variable is the
139+
iv of the system, and assumes that the bounds are the
140+
tspan.
141+
Equivalent to Integral(t in tspan) in Symbolics.
138142
"""
139-
struct IntegralNorm end
143+
struct<: Symbolics.Operator end
144+
(x) = ()(x)
145+
Base.show(io::IO, x::∫) = print(io, "")
140146

141147
"""
142148
$(SIGNATURES)

src/systems/problem_utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,6 @@ function maybe_build_initialization_problem(
640640
t = zero(floatT)
641641
end
642642

643-
@show u0map
644643
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
645644
sys, t, u0map, pmap; guesses, kwargs...)
646645
if state_values(initializeprob) !== nothing

test/extensions/jump_control.jl

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ const M = ModelingToolkit
7777
end
7878

7979
@testset "Linear systems" begin
80-
function is_bangbang(input_sol, lbounds, ubounds)
80+
function is_bangbang(input_sol, lbounds, ubounds, rtol = 1e-4)
8181
bangbang = true
82-
for v in 1:length(input_sol.u[1])
83-
all(i -> i[v] bounds[v] || i[v] bounds[u], input_sol.u) || (bangbang = false)
82+
for v in 1:length(input_sol.u[1]) - 1
83+
all(i -> (i[v], bounds[v]; rtol) || (i[v], bounds[u]; rtol), input_sol.u) || (bangbang = false)
8484
end
8585
bangbang
8686
end
@@ -91,8 +91,8 @@ end
9191
@variables x(..) [bounds = (0., 0.25)] v(..)
9292
@variables u(t) [bounds = (-1., 1.), input = true]
9393
constr = [v(1.0) ~ 0.0]
94-
cost = [-x(1.0)] # Optimize the final distance.
95-
@named block = ODESystem([D(x(t)) ~ v(t), D(v(t)) ~ u], t)
94+
cost = [-x(1.0)] # Maximize the final distance.
95+
@named block = ODESystem([D(x(t)) ~ v(t), D(v(t)) ~ u], t; costs = cost, constraints = constr)
9696
block, input_idxs = structural_simplify(block, ([u],[]))
9797

9898
u0map = [x(t) => 0., v(t) => 0.]
@@ -103,28 +103,85 @@ end
103103
# Linear systems have bang-bang controls
104104
@test is_bangbang(jsol.input_sol, [-1.], [1.])
105105
# Test reached final position.
106-
@test jsol.sol.u[end][1] 0.25
106+
@test (jsol.sol.u[end][1], 0.25, rtol = 1e-5)
107107

108-
# Cart-pole system
108+
iprob = InfiniteOptControlProblem(block, u0map, tspan, parammap; dt = 0.01)
109+
isol = solve(iprob, Ipopt.Optimizer; silent = true)
110+
@test is_bangbang(isol.input_sol, [-1.], [1.])
111+
@test (isol.sol.u[end][1], 0.25, rtol = 1e-5)
109112

110-
# Bee example (from Lawrence Evans' notes)
111-
@variables w(..) q(..)
112-
@parameters α(t) [bounds = [0, 1]] b c μ s ν
113+
###################
114+
### Bee example ###
115+
###################
116+
# From Lawrence Evans' notes
117+
@variables w(..) q(..) α(t) [input = true, bounds = (0, 1)]
118+
@parameters b c μ s ν
113119

114120
tspan = (0, 4)
115121
eqs = [D(w(t)) ~ -μ*w(t) + b*s*α*w(t),
116122
D(q(t)) ~ -ν*q(t) + c*(1 - α)*s*w(t)]
117123
costs = [-q(tspan[2])]
118124

119-
@mtkbuild beesys = ODESystem(eqs, t; costs)
120-
u0map = [w(0) => 40, q(0) => 2]
121-
pmap = [b => 1, c => 1, μ => 1, s => 1, ν => 1]
125+
@named beesys = ODESystem(eqs, t; costs)
126+
beesys, input_idxs = structural_simplify(beesys, ([α],[]))
127+
u0map = [w(t) => 40, q(t) => 2]
128+
pmap = [b => 1, c => 1, μ => 1, s => 1, ν => 1, α => 1]
122129

123-
jprob = JuMPControlProblem(beesys, u0map, tspan, pmap)
130+
jprob = JuMPControlProblem(beesys, u0map, tspan, pmap, dt = 0.01)
124131
jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
125-
control_sol = jsol.control_sol
126-
# Bang-bang control
132+
@test is_bangbang(jsol.input_sol, [0.], [1.])
133+
iprob = InfiniteOptControlProblem(beesys, u0map, tspan, pmap, dt = 0.01)
134+
isol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
135+
@test is_bangbang(isol.input_sol, [0.], [1.])
127136
end
128-
#
137+
138+
@testset "Rocket launch" begin
139+
t = M.t_nounits
140+
D = M.D_nounits
141+
142+
@variables h(..) v(..) m(..) T(..) [input = true, bounds = (0, tₘ)]
143+
@parameters h_c m₀ h₀ g₀ D_c c Tₘ
144+
@parameters tf
145+
drag(h, v) = D_c * v^2 * exp(-h_c * (h - h₀) / h₀)
146+
gravity(h) = g₀ * (h₀ / h)
147+
148+
eqs = [D(h(t)) ~ v(t),
149+
D(v(t)) ~ (T(t) - drag(h(t), v(t))) / m(t) - gravity(t),
150+
D(m(t)) ~ -T(t) / c]
151+
152+
costs = [-h(tf)]
153+
constraints = [T(tf) ~ 0]
154+
@named rocket = ODESystem(eqs, t; costs, constraints)
155+
@test tf Set(parameters(rocket))
156+
157+
u0map = [h(t) => h₀, m(t) => m₀, v(t) => 0]
158+
pmap = [g₀ => 1, m₀ => 1.0, h_c => 500, c => 0.5*√(g₀*h₀), D_C => 0.5 * 620 * m₀/g₀, Tₘ => 3.5*g₀*m₀]
159+
jprob = JuMPControlProblem(rocket, u0map, (0, tf), pmap)
160+
jsol = solve(jprob, Ipopt.Optimizer, :RadauIA3)
161+
@test jsol.sol.u[end][1] 1.012
162+
end
163+
164+
@testset "Free final time problem" begin
165+
t = M.t_nounits
166+
D = M.D_nounits
167+
168+
@variables x(..) u(..) [input = true, bounds = (0,1)]
169+
@parameters tf
170+
eqs = [D(x(t)) ~ -2 + 0.5*u]
171+
172+
# Integral cost function
173+
costs = [(x-u), x(tf)]
174+
consolidate(u) = u[1] + u[2]
175+
jprob = JuMPControlProblem(rocket, u0map, (0, tf), pmap)
176+
jsol = solve(jprob, Ipopt.Optimizer, :RadauIA3)
177+
@test jsol.sol.t[end] 10.0
178+
iprob = InfiniteOptControlProblem(rocket, u0map, (0, tf), pmap)
179+
isol = solve(iprob, Ipopt.Optimizer, :RadauIA3)
180+
@test isol.sol.t[end] 10.0
181+
end
182+
183+
@testset "Cart-pole problem" begin
184+
end
185+
129186
#@testset "Constrained optimal control problems" begin
130187
#end

0 commit comments

Comments
 (0)