Skip to content

Commit 1133e4e

Browse files
committed
feat: solver tableau debugging
1 parent 8e6b444 commit 1133e4e

File tree

7 files changed

+289
-195
lines changed

7 files changed

+289
-195
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
3131
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
3232
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
3333
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
34+
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
3435
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
3536
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
3637
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
@@ -66,8 +67,10 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6667
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6768
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6869
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
70+
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
6971
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
7072
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
73+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
7174
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
7275

7376
[extensions]
@@ -76,7 +79,7 @@ MTKChainRulesCoreExt = "ChainRulesCore"
7679
MTKDeepDiffsExt = "DeepDiffs"
7780
MTKFMIExt = "FMI"
7881
MTKInfiniteOptExt = "InfiniteOpt"
79-
MTKJuMP = "JuMP"
82+
MTKJuMPControlExt = ["JuMP", "DiffEqDevTools"]
8083
MTKLabelledArraysExt = "LabelledArrays"
8184

8285
[compat]
@@ -98,6 +101,7 @@ DeepDiffs = "1"
98101
DelayDiffEq = "5.50"
99102
DiffEqBase = "6.165.1"
100103
DiffEqCallbacks = "2.16, 3, 4"
104+
DiffEqDevTools = "2.48.0"
101105
DiffEqNoiseProcess = "5"
102106
DiffRules = "0.1, 1.0"
103107
DifferentiationInterface = "0.6.47"
@@ -116,6 +120,8 @@ FunctionWrappersWrappers = "0.1"
116120
Graphs = "1.5.2"
117121
InfiniteOpt = "0.5"
118122
InteractiveUtils = "1"
123+
Ipopt = "1.8.0"
124+
JuMP = "1.25.0"
119125
JuliaFormatter = "1.0.47"
120126
JumpProcesses = "9.13.1"
121127
LabelledArrays = "1.3"

ext/MTKJuMPControlExt.jl

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
module MTKJuMPControlExt
2+
using ModelingToolkit
3+
using JuMP, InfiniteOpt
4+
using DiffEqDevTools, DiffEqBase, SciMLBase
5+
using LinearAlgebra
6+
const MTK = ModelingToolkit
7+
8+
struct JuMPControlProblem{uType, tType, P, F, K}
9+
f::F
10+
u0::uType
11+
tspan::tType
12+
p::P
13+
model::InfiniteModel
14+
kwargs::K
15+
16+
function JuMPControlProblem(f, u0, tspan, p, model; kwargs...)
17+
new{typeof(u0), typeof(tspan), typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
18+
end
19+
end
20+
21+
"""
22+
JuMPControlProblem(sys::ODESystem, u0, tspan, p; dt)
23+
24+
Convert an ODESystem representing an optimal control system into a JuMP model
25+
for solving using optimization. Must provide `dt` for determining the length
26+
of the interpolation arrays.
27+
28+
The optimization variables:
29+
- a vector-of-vectors U representing the unknowns as an interpolation array
30+
- a vector-of-vectors V representing the controls as an interpolation array
31+
32+
The constraints are:
33+
- The set of user constraints passed to the ODESystem via `constraints`
34+
- The solver constraints that encode the time-stepping used by the solver
35+
"""
36+
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), kwargs...)
37+
ts = tspan[1]
38+
te = tspan[2]
39+
steps = ts:dt:te
40+
ctrls = controls(sys)
41+
states = unknowns(sys)
42+
constraintsys = MTK.get_constraintsystem(sys)
43+
44+
if !isnothing(constraintsys)
45+
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
46+
@warn "The JuMPControlProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
47+
end
48+
49+
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, u0map, pmap;
50+
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
51+
52+
model = InfiniteModel()
53+
@infinite_parameter(model, t in [ts, te], num_supports = length(steps), derivative_method = OrthogonalCollocation(2))
54+
@variable(model, U[1:length(states)], Infinite(t), start = ts)
55+
@variable(model, V[1:length(ctrls)], Infinite(t), start = ts)
56+
57+
add_jump_cost_function!(model, sys)
58+
add_user_constraints!(model, sys)
59+
60+
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
61+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) : [stidxmap[k] for (k, v) in u0map]
62+
add_initial_constraints!(model, u0, u0_idxs, tspan)
63+
64+
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
65+
end
66+
67+
function add_jump_cost_function!(model, sys)
68+
jcosts = MTK.get_costs(sys)
69+
consolidate = MTK.get_consolidate(sys)
70+
if isnothing(consolidate)
71+
@objective(model, Min, 0)
72+
return
73+
end
74+
iv = MTK.get_iv(sys)
75+
76+
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
77+
cidxmap = Dict([v => i for (i, v) in enumerate(controls(sys))])
78+
79+
for st in unknowns(sys)
80+
x = operation(st)
81+
t = only(arguments(st))
82+
idx = stidxmap[x(iv)]
83+
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
84+
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => subval))
85+
end
86+
87+
for ct in controls(sys)
88+
p = operation(ct)
89+
t = only(arguments(ct))
90+
idx = cidxmap[p(iv)]
91+
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
92+
jcosts = Symbolics.substitute(jcosts, Dict(x(t) => subval))
93+
end
94+
95+
@objective(model, Min, consolidate(jcosts))
96+
end
97+
98+
function add_user_constraints!(model, sys)
99+
jconstraints = if !(csys = MTK.get_constraintsystem(sys) isa Nothing)
100+
MTK.get_constraints(csys)
101+
else
102+
nothing
103+
end
104+
isnothing(jconstraints) && return nothing
105+
106+
iv = MTK.get_iv(sys)
107+
stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))])
108+
cidxmap = Dict([v => i for (i, v) in enumerate(controls(sys))])
109+
110+
for st in unknowns(sys)
111+
x = operation(st)
112+
t = only(arguments(st))
113+
idx = stidxmap[x(iv)]
114+
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
115+
jconstraints = Symbolics.substitute(jconstraints, Dict(x(t) => subval))
116+
end
117+
118+
for ct in controls(sys)
119+
p = operation(ct)
120+
t = only(arguments(ct))
121+
idx = cidxmap[p(iv)]
122+
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
123+
jconstraints = Symbolics.substitute(jconstraints, Dict(p(t) => subval))
124+
end
125+
126+
for (i, cons) in enumerate(jconstraints)
127+
if cons isa Equation
128+
@constraint(model, user[i], cons.lhs - cons.rhs == 0)
129+
elseif cons.relational_op === Symbolics.geq
130+
@constraint(model, user[i], cons.lhs - cons.rhs 0)
131+
else
132+
@constraint(model, user[i], cons.lhs - cons.rhs 0)
133+
end
134+
end
135+
end
136+
137+
function add_initial_constraints!(model, u0, u0_idxs, tspan)
138+
ts = tspan[1]
139+
@constraint(model, init_u0_idx[i in u0_idxs], model[:U][i](ts) == u0[i])
140+
end
141+
142+
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
143+
144+
function add_solve_constraints!(prob, tableau)
145+
A = tableau.A
146+
α = tableau.α
147+
c = tableau.c
148+
model = prob.model
149+
f = prob.f
150+
p = prob.p
151+
tsteps = supports(model[:t])
152+
pop!(tsteps)
153+
dt = tsteps[2] - tsteps[1]
154+
155+
U = model[:U]
156+
nᵤ = length(U)
157+
if is_explicit(tableau)
158+
K = Any[]
159+
for τ in tsteps
160+
for (i, h) in enumerate(c)
161+
ΔU = sum([A[i, j] * K[j] for j in 1:i-1], init = zeros(nᵤ))
162+
Uₙ = [U[i](τ) + ΔU[i]*dt for i in 1:nᵤ]
163+
Kₙ = f(Uₙ, p, τ + h*dt)
164+
push!(K, Kₙ)
165+
end
166+
ΔU = sum([α[i] * K[i] for i in 1:length(α)])
167+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU[n] == U[n](τ + dt))
168+
empty!(K)
169+
end
170+
else
171+
@variable(model, K[1:length(a), 1:nᵤ], Infinite(t), start = tsteps[1])
172+
for τ in tsteps
173+
ΔUs = [A * K(τ)]
174+
for (i, h) in enumerate(c)
175+
ΔU = ΔUs[i]
176+
Uₙ = [U[j](τ) + ΔU[j](τ)*dt for j in 1:nᵤ]
177+
@constraint(model, K[i](τ) == f(Uₙ, p, τ + h*dt))
178+
end
179+
ΔU = sum([α[i] * K[i] for i in 1:length(α)])
180+
@constraint(model, U(τ) + dot(α, K(τ)) == U+ dt))
181+
end
182+
end
183+
end
184+
185+
"""
186+
"""
187+
struct JuMPControlSolution
188+
model::InfiniteModel
189+
sol::ODESolution
190+
end
191+
192+
"""
193+
Solve JuMPControlProblem. Arguments:
194+
- prob: a JumpControlProblem
195+
- jump_solver: a LP solver such as HiGHS
196+
- ode_solver: Takes in a symbol representing the solver. Acceptable solvers may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. Note that the symbol may be different than the typical name of the solver, e.g. :Tsitouras5 rather than Tsit5.
197+
"""
198+
function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Symbol)
199+
model = prob.model
200+
tableau_getter = Symbol(:construct, ode_solver)
201+
@eval tableau = $tableau_getter()
202+
ts = supports(model[:t])
203+
add_solve_constraints!(prob, tableau)
204+
205+
set_optimizer(model, jump_solver)
206+
optimize!(model)
207+
208+
if is_solved_and_feasible(model)
209+
sol = DiffEqBase.build_solution(prob, ode_solver, ts, value.(U))
210+
JuMPControlSolution(model, sol)
211+
else
212+
sol = DiffEqBase.build_solution(prob, ode_solver, ts, value.(U))
213+
sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure)
214+
JuMPControlSolution(model, sol)
215+
end
216+
end
217+
218+
end

0 commit comments

Comments
 (0)