|
| 1 | +module MTKJuMPControlExt |
| 2 | +using ModelingToolkit |
1 | 3 | using JuMP, InfiniteOpt
|
| 4 | +using DiffEqDevTools, DiffEqBase |
| 5 | + |
| 6 | +struct JuMPProblem{uType, tType, isinplace, P, F, K} <: |
| 7 | + AbstractODEProblem{uType, tType, isinplace} |
| 8 | + f::F |
| 9 | + u0::uType |
| 10 | + tspan |
| 11 | + p |
| 12 | + model |
| 13 | + kwargs |
| 14 | +end |
2 | 15 |
|
3 | 16 | """
|
4 |
| -Convert an ODESystem with constraints to a JuMPProblem for optimal control solving. |
| 17 | + JuMPProblem(sys::ODESystem, u0, tspan, p; dt) |
| 18 | +
|
| 19 | +Convert an ODESystem representing an optimal control system into a JuMP model |
| 20 | +for solving using optimization. Must provide `dt` for determining the length |
| 21 | +of the interpolation arrays. |
| 22 | +
|
| 23 | +The optimization variables: |
| 24 | +- a vector-of-vectors U representing the unknowns as an interpolation array |
| 25 | +- a vector-of-vectors V representing the controls as an interpolation array |
| 26 | +
|
| 27 | +The constraints are: |
| 28 | +- The set of user constraints passed to the ODESystem via `constraints` |
| 29 | +- The solver constraints that encode the time-stepping used by the solver |
5 | 30 | """
|
6 |
| -function JuMPProblem(sys::ODESystem, u0, tspan, p; dt = error("dt must be provided for JuMPProblem.")) |
| 31 | +function JuMPProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPProblem."), solver = :Tsit5) |
7 | 32 | ts = tspan[1]
|
8 | 33 | te = tspan[2]
|
9 | 34 | steps = ts:dt:te
|
10 |
| - costs = get_costs(sys) |
11 |
| - consolidate = get_consolidate(sys) |
12 | 35 | ctrls = get_ctrls(sys)
|
13 | 36 | states = unknowns(sys)
|
14 |
| - constraints = get_constraints(get_constraintsystem(sys)) |
15 | 37 |
|
16 |
| - model = Model() |
| 38 | + if !isnothing(constraintsys) |
| 39 | + (length(constraints(constraintsys)) + length(u0map) > length(sts)) && |
| 40 | + @warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization." |
| 41 | + end |
17 | 42 |
|
18 |
| - @infinite_parameter(model, t in [tspan[1],tspan[2]], num_supports = length(steps), derivative_method = OrthogonalCollocation(2)) |
19 |
| - @variables(model, U[1:length(states)], Infinite(t), start = ts) |
20 |
| - @variables(model, V[1:length(ctrls)], Infinite(t), start = ts) |
21 |
| - @variables(model, K) |
| 43 | + model = InfiniteModel() |
| 44 | + @infinite_parameter(model, t in [ts, te], num_supports = length(steps), derivative_method = OrthogonalCollocation(2)) |
| 45 | + @variable(model, U[1:length(states)], Infinite(t), start = ts) |
| 46 | + @variable(model, V[1:length(ctrls)], Infinite(t), start = ts) |
| 47 | + @variable(model, K) |
| 48 | + |
| 49 | + f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap; |
| 50 | + t = tspan !== nothing ? tspan[1] : tspan, guesses, |
| 51 | + check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...) |
| 52 | + |
| 53 | + add_jump_cost_function!(model, sys) |
| 54 | + add_user_constraints!(model, sys) |
| 55 | + add_solve_constraints!(model) |
| 56 | + |
| 57 | + JuMPProblem{iip}(f, u0, tspan, p, model; kwargs...) |
| 58 | +end |
| 59 | + |
| 60 | +function add_jump_cost_function!(model, sys) |
| 61 | + jcosts = get_costs(sys) |
| 62 | + consolidate = get_consolidate(sys) |
| 63 | + iv = get_iv(sys) |
22 | 64 |
|
23 |
| - jcost = generate_jump_cost_function(sys) |
24 |
| - @objective |
| 65 | + stidxmap = Dict([v => i for (i, v) in enumerate(get_unknowns(sys))]) |
| 66 | + cidxmap = Dict([v => i for (i, v) in enumerate(get_ctrls(sys))]) |
25 | 67 |
|
26 |
| - constraints = generate_jump_constraints(constraints) |
27 |
| - @constraints |
| 68 | + for st in get_unknowns(sys) |
| 69 | + x = operation(st) |
| 70 | + t = only(arguments(st)) |
| 71 | + idx = stidxmap[x(iv)] |
| 72 | + jcosts = Symbolics.substitute(costs, Dict(x(t) => model[:U][idx](t))) |
| 73 | + end |
| 74 | + |
| 75 | + for ct in get_ctrls(sys) |
| 76 | + p = operation(ct) |
| 77 | + t = only(arguments(ct)) |
| 78 | + idx = cidxmap[p(iv)] |
| 79 | + jcosts = Symbolics.substitute(costs, Dict(p(t) => model[:V][idx](t))) |
| 80 | + end |
| 81 | + |
| 82 | + @objective(model, Min, consolidate(jcosts)) |
28 | 83 | end
|
29 | 84 |
|
30 |
| -function generate_jump_cost_function(costs, tsteps) |
| 85 | +function add_user_constraints!(model, sys, u0map) |
| 86 | + jconstraints = get_constraints(get_constraintsystem(sys)) |
| 87 | + iv = get_iv(sys) |
| 88 | + |
| 89 | + stidxmap = Dict([v => i for (i, v) in enumerate(get_unknowns(sys))]) |
| 90 | + cidxmap = Dict([v => i for (i, v) in enumerate(get_ctrls(sys))]) |
| 91 | + |
| 92 | + for st in get_unknowns(sys) |
| 93 | + x = operation(st) |
| 94 | + t = only(arguments(st)) |
| 95 | + idx = stidxmap[x(iv)] |
| 96 | + subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t) |
| 97 | + jconstraints = Symbolics.substitute(constraints, Dict(x(t) => subval)) |
| 98 | + end |
| 99 | + |
| 100 | + for ct in get_ctrls(sys) |
| 101 | + p = operation(ct) |
| 102 | + t = only(arguments(ct)) |
| 103 | + idx = cidxmap[p(iv)] |
| 104 | + subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t) |
| 105 | + jconstraints = Symbolics.substitute(constraints, Dict(p(t) => subval)) |
| 106 | + end |
| 107 | + |
| 108 | + for (i, cons) in enumerate(jconstraints) |
| 109 | + if cons isa Equation |
| 110 | + @constraint(model, user[i], cons.lhs - cons.rhs == 0) |
| 111 | + elseif cons.relational_op === Symbolics.geq |
| 112 | + @constraint(model, user[i], cons.lhs - cons.rhs ≥ 0) |
| 113 | + else |
| 114 | + @constraint(model, user[i], cons.lhs - cons.rhs ≤ 0) |
| 115 | + end |
| 116 | + end |
| 117 | + |
| 118 | + # Add initial constraints. |
31 | 119 | end
|
32 | 120 |
|
33 |
| -function generate_jump_constraints(constraints, jump_vars, jump_ps) |
| 121 | +function add_solve_constraints!(model, tsteps, solver) |
| 122 | + tableau = fetch_tableau(solver) |
| 123 | + |
| 124 | + for (i, t) in collect(enumerate(tsteps)) |
| 125 | + end |
34 | 126 | end
|
35 | 127 |
|
36 |
| -function t_to_tstep() |
37 |
| - |
| 128 | +""" |
| 129 | +Solve JuMPProblem. Takes in a symbol representing the solver. |
| 130 | +""" |
| 131 | +function solve(prob::JuMPProblem, solver_sym::Symbol) |
| 132 | + model = prob.model |
| 133 | + tableau_getter = Symbol(:construct, solver) |
| 134 | + @eval tableau = $tableau_getter() |
| 135 | + add_solve_constraints!(model, tableau) |
| 136 | +end |
38 | 137 | end
|
0 commit comments