Skip to content

Commit 019db8b

Browse files
committed
add solver getter
1 parent 8b19eab commit 019db8b

File tree

1 file changed

+117
-18
lines changed

1 file changed

+117
-18
lines changed

ext/MTKJuMPExt.jl

Lines changed: 117 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,137 @@
1+
module MTKJuMPControlExt
2+
using ModelingToolkit
13
using JuMP, InfiniteOpt
4+
using DiffEqDevTools, DiffEqBase
5+
6+
struct JuMPProblem{uType, tType, isinplace, P, F, K} <:
7+
AbstractODEProblem{uType, tType, isinplace}
8+
f::F
9+
u0::uType
10+
tspan
11+
p
12+
model
13+
kwargs
14+
end
215

316
"""
4-
Convert an ODESystem with constraints to a JuMPProblem for optimal control solving.
17+
JuMPProblem(sys::ODESystem, u0, tspan, p; dt)
18+
19+
Convert an ODESystem representing an optimal control system into a JuMP model
20+
for solving using optimization. Must provide `dt` for determining the length
21+
of the interpolation arrays.
22+
23+
The optimization variables:
24+
- a vector-of-vectors U representing the unknowns as an interpolation array
25+
- a vector-of-vectors V representing the controls as an interpolation array
26+
27+
The constraints are:
28+
- The set of user constraints passed to the ODESystem via `constraints`
29+
- The solver constraints that encode the time-stepping used by the solver
530
"""
6-
function JuMPProblem(sys::ODESystem, u0, tspan, p; dt = error("dt must be provided for JuMPProblem."))
31+
function JuMPProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPProblem."), solver = :Tsit5)
732
ts = tspan[1]
833
te = tspan[2]
934
steps = ts:dt:te
10-
costs = get_costs(sys)
11-
consolidate = get_consolidate(sys)
1235
ctrls = get_ctrls(sys)
1336
states = unknowns(sys)
14-
constraints = get_constraints(get_constraintsystem(sys))
1537

16-
model = Model()
38+
if !isnothing(constraintsys)
39+
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
40+
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
41+
end
1742

18-
@infinite_parameter(model, t in [tspan[1],tspan[2]], num_supports = length(steps), derivative_method = OrthogonalCollocation(2))
19-
@variables(model, U[1:length(states)], Infinite(t), start = ts)
20-
@variables(model, V[1:length(ctrls)], Infinite(t), start = ts)
21-
@variables(model, K)
43+
model = InfiniteModel()
44+
@infinite_parameter(model, t in [ts, te], num_supports = length(steps), derivative_method = OrthogonalCollocation(2))
45+
@variable(model, U[1:length(states)], Infinite(t), start = ts)
46+
@variable(model, V[1:length(ctrls)], Infinite(t), start = ts)
47+
@variable(model, K)
48+
49+
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
50+
t = tspan !== nothing ? tspan[1] : tspan, guesses,
51+
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
52+
53+
add_jump_cost_function!(model, sys)
54+
add_user_constraints!(model, sys)
55+
add_solve_constraints!(model)
56+
57+
JuMPProblem{iip}(f, u0, tspan, p, model; kwargs...)
58+
end
59+
60+
function add_jump_cost_function!(model, sys)
61+
jcosts = get_costs(sys)
62+
consolidate = get_consolidate(sys)
63+
iv = get_iv(sys)
2264

23-
jcost = generate_jump_cost_function(sys)
24-
@objective
65+
stidxmap = Dict([v => i for (i, v) in enumerate(get_unknowns(sys))])
66+
cidxmap = Dict([v => i for (i, v) in enumerate(get_ctrls(sys))])
2567

26-
constraints = generate_jump_constraints(constraints)
27-
@constraints
68+
for st in get_unknowns(sys)
69+
x = operation(st)
70+
t = only(arguments(st))
71+
idx = stidxmap[x(iv)]
72+
jcosts = Symbolics.substitute(costs, Dict(x(t) => model[:U][idx](t)))
73+
end
74+
75+
for ct in get_ctrls(sys)
76+
p = operation(ct)
77+
t = only(arguments(ct))
78+
idx = cidxmap[p(iv)]
79+
jcosts = Symbolics.substitute(costs, Dict(p(t) => model[:V][idx](t)))
80+
end
81+
82+
@objective(model, Min, consolidate(jcosts))
2883
end
2984

30-
function generate_jump_cost_function(costs, tsteps)
85+
function add_user_constraints!(model, sys, u0map)
86+
jconstraints = get_constraints(get_constraintsystem(sys))
87+
iv = get_iv(sys)
88+
89+
stidxmap = Dict([v => i for (i, v) in enumerate(get_unknowns(sys))])
90+
cidxmap = Dict([v => i for (i, v) in enumerate(get_ctrls(sys))])
91+
92+
for st in get_unknowns(sys)
93+
x = operation(st)
94+
t = only(arguments(st))
95+
idx = stidxmap[x(iv)]
96+
subval = isequal(t, iv) ? model[:U][idx] : model[:U][idx](t)
97+
jconstraints = Symbolics.substitute(constraints, Dict(x(t) => subval))
98+
end
99+
100+
for ct in get_ctrls(sys)
101+
p = operation(ct)
102+
t = only(arguments(ct))
103+
idx = cidxmap[p(iv)]
104+
subval = isequal(t, iv) ? model[:V][idx] : model[:V][idx](t)
105+
jconstraints = Symbolics.substitute(constraints, Dict(p(t) => subval))
106+
end
107+
108+
for (i, cons) in enumerate(jconstraints)
109+
if cons isa Equation
110+
@constraint(model, user[i], cons.lhs - cons.rhs == 0)
111+
elseif cons.relational_op === Symbolics.geq
112+
@constraint(model, user[i], cons.lhs - cons.rhs 0)
113+
else
114+
@constraint(model, user[i], cons.lhs - cons.rhs 0)
115+
end
116+
end
117+
118+
# Add initial constraints.
31119
end
32120

33-
function generate_jump_constraints(constraints, jump_vars, jump_ps)
121+
function add_solve_constraints!(model, tsteps, solver)
122+
tableau = fetch_tableau(solver)
123+
124+
for (i, t) in collect(enumerate(tsteps))
125+
end
34126
end
35127

36-
function t_to_tstep()
37-
128+
"""
129+
Solve JuMPProblem. Takes in a symbol representing the solver.
130+
"""
131+
function solve(prob::JuMPProblem, solver_sym::Symbol)
132+
model = prob.model
133+
tableau_getter = Symbol(:construct, solver)
134+
@eval tableau = $tableau_getter()
135+
add_solve_constraints!(model, tableau)
136+
end
38137
end

0 commit comments

Comments
 (0)