|
| 1 | +module MTKJuMPControlExt |
| 2 | +using ModelingToolkit |
| 3 | +using JuMP, InfiniteOpt |
| 4 | +using DiffEqDevTools, DiffEqBase, SciMLBase |
| 5 | +using LinearAlgebra |
| 6 | +const MTK = ModelingToolkit |
| 7 | + |
| 8 | +struct JuMPControlProblem{uType, tType, P, F, K} |
| 9 | + f::F |
| 10 | + u0::uType |
| 11 | + tspan::tType |
| 12 | + p::P |
| 13 | + model::InfiniteModel |
| 14 | + kwargs::K |
| 15 | + |
| 16 | + function JuMPControlProblem(f, u0, tspan, p, model; kwargs...) |
| 17 | + new{typeof(u0), typeof(tspan), typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs) |
| 18 | + end |
| 19 | +end |
| 20 | + |
| 21 | +""" |
| 22 | + JuMPControlProblem(sys::ODESystem, u0, tspan, p; dt) |
| 23 | +
|
| 24 | +Convert an ODESystem representing an optimal control system into a JuMP model |
| 25 | +for solving using optimization. Must provide `dt` for determining the length |
| 26 | +of the interpolation arrays. |
| 27 | +
|
| 28 | +The optimization variables: |
| 29 | +- a vector-of-vectors U representing the unknowns as an interpolation array |
| 30 | +- a vector-of-vectors V representing the controls as an interpolation array |
| 31 | +
|
| 32 | +The constraints are: |
| 33 | +- The set of user constraints passed to the ODESystem via `constraints` |
| 34 | +- The solver constraints that encode the time-stepping used by the solver |
| 35 | +""" |
| 36 | +function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), kwargs...) |
| 37 | + ts = tspan[1] |
| 38 | + te = tspan[2] |
| 39 | + steps = ts:dt:te |
| 40 | + ctrls = controls(sys) |
| 41 | + states = unknowns(sys) |
| 42 | + constraintsys = MTK.get_constraintsystem(sys) |
| 43 | + |
| 44 | + if !isnothing(constraintsys) |
| 45 | + (length(constraints(constraintsys)) + length(u0map) > length(sts)) && |
| 46 | + @warn "The JuMPControlProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization." |
| 47 | + end |
| 48 | + |
| 49 | + f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, u0map, pmap; |
| 50 | + t = tspan !== nothing ? tspan[1] : tspan, kwargs...) |
| 51 | + |
| 52 | + model = InfiniteModel() |
| 53 | + @infinite_parameter(model, t in [ts, te], num_supports = length(steps), derivative_method = OrthogonalCollocation(2)) |
| 54 | + @variable(model, U[1:length(states)], Infinite(t), start = ts) |
| 55 | + @variable(model, V[1:length(ctrls)], Infinite(t), start = ts) |
| 56 | + |
| 57 | + add_jump_cost_function!(model, sys) |
| 58 | + add_user_constraints!(model, sys) |
| 59 | + |
| 60 | + stidxmap = Dict([v => i for (i, v) in enumerate(states)]) |
| 61 | + u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) : [stidxmap[k] for (k, v) in u0map] |
| 62 | + add_initial_constraints!(model, u0, u0_idxs, tspan) |
| 63 | + |
| 64 | + JuMPControlProblem(f, u0, tspan, p, model, kwargs...) |
| 65 | +end |
| 66 | + |
| 67 | +function add_jump_cost_function!(model, sys) |
| 68 | + jcosts = MTK.get_costs(sys) |
| 69 | + consolidate = MTK.get_consolidate(sys) |
| 70 | + if isnothing(consolidate) |
| 71 | + @objective(model, Min, 0) |
| 72 | + return |
| 73 | + end |
| 74 | + iv = MTK.get_iv(sys) |
| 75 | + |
| 76 | + stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))]) |
| 77 | + cidxmap = Dict([v => i for (i, v) in enumerate(controls(sys))]) |
| 78 | + |
| 79 | + for st in unknowns(sys) |
| 80 | + x = operation(st) |
| 81 | + t = only(arguments(st)) |
| 82 | + idx = stidxmap[x(iv)] |
| 83 | + subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t) |
| 84 | + jcosts = Symbolics.substitute(jcosts, Dict(x(t) => subval)) |
| 85 | + end |
| 86 | + |
| 87 | + for ct in controls(sys) |
| 88 | + p = operation(ct) |
| 89 | + t = only(arguments(ct)) |
| 90 | + idx = cidxmap[p(iv)] |
| 91 | + subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t) |
| 92 | + jcosts = Symbolics.substitute(jcosts, Dict(x(t) => subval)) |
| 93 | + end |
| 94 | + |
| 95 | + @objective(model, Min, consolidate(jcosts)) |
| 96 | +end |
| 97 | + |
| 98 | +function add_user_constraints!(model, sys) |
| 99 | + jconstraints = if !(csys = MTK.get_constraintsystem(sys) isa Nothing) |
| 100 | + MTK.get_constraints(csys) |
| 101 | + else |
| 102 | + nothing |
| 103 | + end |
| 104 | + isnothing(jconstraints) && return nothing |
| 105 | + |
| 106 | + iv = MTK.get_iv(sys) |
| 107 | + stidxmap = Dict([v => i for (i, v) in enumerate(unknowns(sys))]) |
| 108 | + cidxmap = Dict([v => i for (i, v) in enumerate(controls(sys))]) |
| 109 | + |
| 110 | + for st in unknowns(sys) |
| 111 | + x = operation(st) |
| 112 | + t = only(arguments(st)) |
| 113 | + idx = stidxmap[x(iv)] |
| 114 | + subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t) |
| 115 | + jconstraints = Symbolics.substitute(jconstraints, Dict(x(t) => subval)) |
| 116 | + end |
| 117 | + |
| 118 | + for ct in controls(sys) |
| 119 | + p = operation(ct) |
| 120 | + t = only(arguments(ct)) |
| 121 | + idx = cidxmap[p(iv)] |
| 122 | + subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t) |
| 123 | + jconstraints = Symbolics.substitute(jconstraints, Dict(p(t) => subval)) |
| 124 | + end |
| 125 | + |
| 126 | + for (i, cons) in enumerate(jconstraints) |
| 127 | + if cons isa Equation |
| 128 | + @constraint(model, user[i], cons.lhs - cons.rhs == 0) |
| 129 | + elseif cons.relational_op === Symbolics.geq |
| 130 | + @constraint(model, user[i], cons.lhs - cons.rhs ≥ 0) |
| 131 | + else |
| 132 | + @constraint(model, user[i], cons.lhs - cons.rhs ≤ 0) |
| 133 | + end |
| 134 | + end |
| 135 | +end |
| 136 | + |
| 137 | +function add_initial_constraints!(model, u0, u0_idxs, tspan) |
| 138 | + ts = tspan[1] |
| 139 | + @constraint(model, init_u0_idx[i in u0_idxs], model[:U][i](ts) == u0[i]) |
| 140 | +end |
| 141 | + |
| 142 | +is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau |
| 143 | + |
| 144 | +function add_solve_constraints!(prob, tableau) |
| 145 | + A = tableau.A |
| 146 | + α = tableau.α |
| 147 | + c = tableau.c |
| 148 | + model = prob.model |
| 149 | + f = prob.f |
| 150 | + p = prob.p |
| 151 | + tsteps = supports(model[:t]) |
| 152 | + pop!(tsteps) |
| 153 | + dt = tsteps[2] - tsteps[1] |
| 154 | + |
| 155 | + U = model[:U] |
| 156 | + nᵤ = length(U) |
| 157 | + if is_explicit(tableau) |
| 158 | + K = Any[] |
| 159 | + for τ in tsteps |
| 160 | + for (i, h) in enumerate(c) |
| 161 | + ΔU = sum([A[i, j] * K[j] for j in 1:i-1], init = zeros(nᵤ)) |
| 162 | + Uₙ = [U[i](τ) + ΔU[i]*dt for i in 1:nᵤ] |
| 163 | + Kₙ = f(Uₙ, p, τ + h*dt) |
| 164 | + push!(K, Kₙ) |
| 165 | + end |
| 166 | + ΔU = sum([α[i] * K[i] for i in 1:length(α)]) |
| 167 | + @constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU[n] == U[n](τ + dt)) |
| 168 | + empty!(K) |
| 169 | + end |
| 170 | + else |
| 171 | + @variable(model, K[1:length(a), 1:nᵤ], Infinite(t), start = tsteps[1]) |
| 172 | + for τ in tsteps |
| 173 | + ΔUs = [A * K(τ)] |
| 174 | + for (i, h) in enumerate(c) |
| 175 | + ΔU = ΔUs[i] |
| 176 | + Uₙ = [U[j](τ) + ΔU[j](τ)*dt for j in 1:nᵤ] |
| 177 | + @constraint(model, K[i](τ) == f(Uₙ, p, τ + h*dt)) |
| 178 | + end |
| 179 | + ΔU = sum([α[i] * K[i] for i in 1:length(α)]) |
| 180 | + @constraint(model, U(τ) + dot(α, K(τ)) == U(τ + dt)) |
| 181 | + end |
| 182 | + end |
| 183 | +end |
| 184 | + |
| 185 | +""" |
| 186 | +""" |
| 187 | +struct JuMPControlSolution |
| 188 | + model::InfiniteModel |
| 189 | + sol::ODESolution |
| 190 | +end |
| 191 | + |
| 192 | +""" |
| 193 | +Solve JuMPControlProblem. Arguments: |
| 194 | +- prob: a JumpControlProblem |
| 195 | +- jump_solver: a LP solver such as HiGHS |
| 196 | +- ode_solver: Takes in a symbol representing the solver. Acceptable solvers may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. Note that the symbol may be different than the typical name of the solver, e.g. :Tsitouras5 rather than Tsit5. |
| 197 | +""" |
| 198 | +function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Symbol) |
| 199 | + model = prob.model |
| 200 | + tableau_getter = Symbol(:construct, ode_solver) |
| 201 | + @eval tableau = $tableau_getter() |
| 202 | + ts = supports(model[:t]) |
| 203 | + add_solve_constraints!(prob, tableau) |
| 204 | + |
| 205 | + set_optimizer(model, jump_solver) |
| 206 | + optimize!(model) |
| 207 | + |
| 208 | + if is_solved_and_feasible(model) |
| 209 | + sol = DiffEqBase.build_solution(prob, ode_solver, ts, value.(U)) |
| 210 | + JuMPControlSolution(model, sol) |
| 211 | + else |
| 212 | + sol = DiffEqBase.build_solution(prob, ode_solver, ts, value.(U)) |
| 213 | + sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure) |
| 214 | + JuMPControlSolution(model, sol) |
| 215 | + end |
| 216 | +end |
| 217 | + |
| 218 | +end |
0 commit comments