Skip to content

Commit a567ced

Browse files
committed
Implement solver tableau
1 parent 88ab0d3 commit a567ced

File tree

1 file changed

+59
-12
lines changed

1 file changed

+59
-12
lines changed

ext/MTKJuMPExt.jl

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@ using ModelingToolkit
33
using JuMP, InfiniteOpt
44
using DiffEqDevTools, DiffEqBase
55

6-
struct JuMPProblem{uType, tType, isinplace, P, F, K} <:
6+
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <:
77
AbstractODEProblem{uType, tType, isinplace}
88
f::F
99
u0::uType
10-
tspan
10+
tspan::tType
1111
p
1212
model
1313
kwargs
1414
end
1515

1616
"""
17-
JuMPProblem(sys::ODESystem, u0, tspan, p; dt)
17+
JuMPControlProblem(sys::ODESystem, u0, tspan, p; dt)
1818
1919
Convert an ODESystem representing an optimal control system into a JuMP model
2020
for solving using optimization. Must provide `dt` for determining the length
@@ -28,7 +28,7 @@ The constraints are:
2828
- The set of user constraints passed to the ODESystem via `constraints`
2929
- The solver constraints that encode the time-stepping used by the solver
3030
"""
31-
function JuMPProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPProblem."), solver = :Tsit5)
31+
function JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPProblem."), solver = :Tsit5)
3232
ts = tspan[1]
3333
te = tspan[2]
3434
steps = ts:dt:te
@@ -54,7 +54,7 @@ function JuMPProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be
5454
add_user_constraints!(model, sys)
5555
add_solve_constraints!(model)
5656

57-
JuMPProblem{iip}(f, u0, tspan, p, model; kwargs...)
57+
JuMPControlProblem{iip}(f, u0, tspan, p, model; kwargs...)
5858
end
5959

6060
function add_jump_cost_function!(model, sys)
@@ -118,20 +118,67 @@ function add_user_constraints!(model, sys, u0map)
118118
# Add initial constraints.
119119
end
120120

121-
function add_solve_constraints!(model, tsteps, solver)
122-
tableau = fetch_tableau(solver)
123-
124-
for (i, t) in collect(enumerate(tsteps))
121+
function add_solve_constraints!(prob, talbeau, f, tsteps)
122+
A = tableau.A
123+
α = tableau.α
124+
c = tableau.c
125+
model = prob.model
126+
p = prob.p
127+
dt = step(tsteps)
128+
129+
if is_explicit(tableau)
130+
K = Any[]
131+
for t in tsteps
132+
for (i, h) in enumerate(c)
133+
ΔU = sum([A[i, j] * K[j] for j in 1:i-1])
134+
Kₙ = f(U + ΔU*dt, p, t + h*dt)
135+
push!(K, Kₙ)
136+
end
137+
@constraint(model, U(t) + dot(α, K) == U(t + dt))
138+
empty!(K)
139+
end
140+
else
141+
@variable(model, K[1:length(a)], Infinite(t), start = tsteps[1])
142+
for t in tsteps
143+
ΔUs = A * K(t)
144+
for (i, h) in enumerate(c)
145+
ΔU = ΔUs[i]
146+
@constraint(model, K[i](t) == f(U + ΔU*dt, p, t + h*dt))
147+
end
148+
@constraint(model, U(t) + dot(α, K(t)) == U(t + dt))
149+
end
125150
end
126151
end
127152

153+
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
154+
128155
"""
129-
Solve JuMPProblem. Takes in a symbol representing the solver.
130156
"""
131-
function solve(prob::JuMPProblem, solver_sym::Symbol)
157+
struct JuMPControlSolution
158+
model
159+
sol::ODESolution
160+
end
161+
162+
"""
163+
Solve JuMPProblem. Takes in a symbol representing the solver. Acceptable solvers may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/.
164+
Note that the symbol may be different than the typical
165+
name of the solver, e.g. :Tsitouras5 rather than Tsit5.
166+
"""
167+
function solve(prob::JuMPProblem, jump_solver, ode_solver::Symbol)
132168
model = prob.model
169+
f = prob.f
133170
tableau_getter = Symbol(:construct, solver)
134171
@eval tableau = $tableau_getter()
135-
add_solve_constraints!(model, tableau)
172+
ts = prob.tspan[1]:dt:prob.tspan[2]
173+
add_solve_constraints!(model, ts, tableau, f)
174+
175+
set_optimizer(model, solver)
176+
optimize!(model)
177+
178+
if is_solved_and_feasible(model)
179+
sol = DiffEqBase.build_solution(prob, ode_solver, ts, value(U))
180+
JuMPControlSolution(model, sol)
181+
end
136182
end
183+
137184
end

0 commit comments

Comments
 (0)