Skip to content

Commit afeb698

Browse files
committed
up?
1 parent 5415bec commit afeb698

File tree

3 files changed

+55
-30
lines changed

3 files changed

+55
-30
lines changed

ext/MTKJuMPControlExt.jl

+27-20
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ The constraints are:
3333
- The set of user constraints passed to the ODESystem via `constraints`
3434
- The solver constraints that encode the time-stepping used by the solver
3535
"""
36-
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), kwargs...)
36+
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), guesses = Dict(), kwargs...)
3737
ts = tspan[1]
3838
te = tspan[2]
3939
steps = ts:dt:te
@@ -42,11 +42,12 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("
4242
constraintsys = MTK.get_constraintsystem(sys)
4343

4444
if !isnothing(constraintsys)
45-
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
45+
(length(constraints(constraintsys)) + length(u0map) > length(states)) &&
4646
@warn "The JuMPControlProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
4747
end
4848

49-
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, u0map, pmap;
49+
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
50+
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
5051
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
5152

5253
model = InfiniteModel()
@@ -67,7 +68,7 @@ end
6768
function add_jump_cost_function!(model, sys)
6869
jcosts = MTK.get_costs(sys)
6970
consolidate = MTK.get_consolidate(sys)
70-
if isnothing(consolidate)
71+
if isnothing(jcosts)
7172
@objective(model, Min, 0)
7273
return
7374
end
@@ -81,55 +82,52 @@ function add_jump_cost_function!(model, sys)
8182
t = only(arguments(st))
8283
idx = stidxmap[x(iv)]
8384
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
84-
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => subval))
85+
jcosts = map(c -> Symbolics.substitute(c, Dict(x(t) => subval)), jcosts)
8586
end
8687

8788
for ct in controls(sys)
8889
p = operation(ct)
8990
t = only(arguments(ct))
9091
idx = cidxmap[p(iv)]
9192
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
92-
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => subval))
93+
jcosts = map(c -> Symbolics.substitute(c, Dict(p(t) => subval)), jcosts)
9394
end
9495

9596
@objective(model, Min, consolidate(jcosts))
9697
end
9798

9899
function add_user_constraints!(model, sys)
99-
jconstraints = if !(csys = MTK.get_constraintsystem(sys) isa Nothing)
100-
MTK.get_constraints(csys)
101-
else
102-
nothing
103-
end
104-
isnothing(jconstraints) && return nothing
100+
conssys = MTK.get_constraintsystem(sys)
101+
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
102+
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
105103

106104
iv = MTK.get_iv(sys)
107105
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
108106
cidxmap = Dict([v => i for (i, v) in enumerate(controls(sys))])
109107

110-
for st in unknowns(sys)
108+
for st in unknowns(conssys)
111109
x = operation(st)
112110
t = only(arguments(st))
113111
idx = stidxmap[x(iv)]
114112
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
115-
jconstraints = Symbolics.substitute(jconstraints, Dict(x(t) => subval))
113+
jconstraints = map(c -> Symbolics.substitute(c, Dict(x(t) => subval)), jconstraints)
116114
end
117115

118116
for ct in controls(sys)
119117
p = operation(ct)
120118
t = only(arguments(ct))
121119
idx = cidxmap[p(iv)]
122120
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
123-
jconstraints = Symbolics.substitute(jconstraints, Dict(p(t) => subval))
121+
jconstraints = map(c -> Symbolics.substitute(jconstraints, Dict(p(t) => subval)), jconstriants)
124122
end
125123

126124
for (i, cons) in enumerate(jconstraints)
127125
if cons isa Equation
128-
@constraint(model, user[i], cons.lhs - cons.rhs == 0)
126+
@constraint(model, cons.lhs - cons.rhs == 0, base_name = "user[$i]")
129127
elseif cons.relational_op === Symbolics.geq
130-
@constraint(model, user[i], cons.lhs - cons.rhs 0)
128+
@constraint(model, cons.lhs - cons.rhs 0, base_name = "user[$i]")
131129
else
132-
@constraint(model, user[i], cons.lhs - cons.rhs 0)
130+
@constraint(model, cons.lhs - cons.rhs 0, base_name = "user[$i]")
133131
end
134132
end
135133
end
@@ -189,6 +187,7 @@ end
189187
struct JuMPControlSolution
190188
model::InfiniteModel
191189
sol::ODESolution
190+
input_sol::Union{Nothing, ODESolution}
192191
end
193192

194193
"""
@@ -213,7 +212,6 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
213212
end
214213
end
215214
for var in all_variables(model)
216-
@show JuMP.name(var)
217215
if occursin("K", JuMP.name(var))
218216
unregister(model, Symbol(JuMP.name(var)))
219217
delete(model, var)
@@ -232,10 +230,19 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
232230
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:length(ts)]
233231
sol = DiffEqBase.build_solution(prob, ode_solver, ts, U_vals)
234232

233+
input_sol = nothing
234+
if !isempty(model[:V])
235+
V_vals = value.(model[:V])
236+
V_vals = [[V_vals[i][j] for i in 1:length(V_vals)] for j in 1:length(ts)]
237+
input_sol = DiffEqBase.build_solution(prob, ode_solver, ts, V_vals)
238+
end
239+
235240
if !(pstatus === FEASIBLE_POINT && (tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL || tstatus === ALMOST_LOCALLY_SOLVED))
236241
sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure)
242+
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
237243
end
238-
JuMPControlSolution(model, sol)
244+
245+
JuMPControlSolution(model, sol, input_sol)
239246
end
240247

241248
end

src/systems/diffeqs/odesystem.jl

+2
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
336336

337337
if length(costs) > 1 && isnothing(consolidate)
338338
error("Must specify a consolidation function for the costs vector.")
339+
elseif isnothing(consolidate)
340+
consolidate(u) = u[1]
339341
end
340342

341343
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)

test/extensions/jump_control.jl

+26-10
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ const M = ModelingToolkit
1717

1818
tspan = (0.0, 1.0)
1919
u0map = [x(t) => 4.0, y(t) => 2.0]
20-
parammap ==> 7.5, β => 4, γ => 8.0, δ => 5.0]
20+
parammap ==> 1.5, β => 1.0, γ => 3.0, δ => 1.0]
2121
@mtkbuild sys = ODESystem(eqs, t)
2222

2323
# Test explicit method.
@@ -39,17 +39,33 @@ const M = ModelingToolkit
3939
constr = [x(0.6) ~ 3.5, x(0.3) ~ 7.0]
4040
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
4141

42-
jprob = JuMPControlProblem(sys, u0map, tspan, parammap; guesses, dt = 0.01)
43-
@test num_constraints(jprob.model) == 2 == num_variables(jprob.model) == 2
44-
jsol = solve(prob, Ipopt.Optimizer, :Tsitouras5)
42+
jprob = JuMPControlProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
43+
@test num_constraints(jprob.model) == 2
44+
jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
4545
sol = jsol.sol
4646
@test sol(0.6)[1] 3.5
4747
@test sol(0.3)[1] 7.0
4848
end
4949

50-
@testset "Optimal control problem" begin
51-
# Investing
52-
53-
54-
# Bang-bang control
55-
end
50+
#@testset "Optimal control: bees" begin
51+
# # Example from Lawrence Evans' notes
52+
# M.@variables w(..) q(..)
53+
# M.@parameters α(t) [bounds = [0, 1]] b c μ s ν
54+
#
55+
# tspan = (0, 4)
56+
# eqs = [D(w(t)) ~ -μ*w(t) + b*s*α*w(t),
57+
# D(q(t)) ~ -ν*q(t) + c*(1 - α)*s*w(t)]
58+
# costs = [-q(tspan[2])]
59+
#
60+
# @mtkbuild beesys = ODESystem(eqs, t; costs)
61+
# u0map = [w(0) => 40, q(0) => 2]
62+
# pmap = [b => 1, c => 1, μ => 1, s => 1, ν => 1]
63+
#
64+
# jprob = JuMPControlProblem(beesys, u0map, tspan, pmap)
65+
# jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
66+
# control_sol = jsol.control_sol
67+
# # Bang-bang control
68+
#end
69+
#
70+
#@testset "Constrained optimal control problems" begin
71+
#end

0 commit comments

Comments
 (0)