Skip to content

Commit 8e6b444

Browse files
committed
up
1 parent a567ced commit 8e6b444

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

ext/MTKJuMPExt.jl

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <:
88
f::F
99
u0::uType
1010
tspan::tType
11-
p
12-
model
13-
kwargs
11+
p::P
12+
model::Model
13+
kwargs::K
1414
end
1515

1616
"""
@@ -28,7 +28,7 @@ The constraints are:
2828
- The set of user constraints passed to the ODESystem via `constraints`
2929
- The solver constraints that encode the time-stepping used by the solver
3030
"""
31-
function JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPProblem."), solver = :Tsit5)
31+
function JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), guesses, eval_expression, eval_module)
3232
ts = tspan[1]
3333
te = tspan[2]
3434
steps = ts:dt:te
@@ -37,7 +37,7 @@ function JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt m
3737

3838
if !isnothing(constraintsys)
3939
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
40-
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
40+
@warn "The JuMPControlProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
4141
end
4242

4343
model = InfiniteModel()
@@ -52,9 +52,12 @@ function JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt m
5252

5353
add_jump_cost_function!(model, sys)
5454
add_user_constraints!(model, sys)
55-
add_solve_constraints!(model)
5655

57-
JuMPControlProblem{iip}(f, u0, tspan, p, model; kwargs...)
56+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
57+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k, v) in u0map]
58+
add_initial_constraints!(model, u0, u0_idxs, tspan)
59+
60+
JuMPControlProblem{iip}(f, u0, tspan, p, model, kwargs...)
5861
end
5962

6063
function add_jump_cost_function!(model, sys)
@@ -114,44 +117,48 @@ function add_user_constraints!(model, sys, u0map)
114117
@constraint(model, user[i], cons.lhs - cons.rhs 0)
115118
end
116119
end
120+
end
117121

118-
# Add initial constraints.
122+
function add_initial_constraints!(model, u0, u0_idxs, tspan)
123+
ts = tspan[1]
124+
@constraint(model, init_u0_idx[i in u0_idxs], U[i](ts) == u0[i])
119125
end
120126

121-
function add_solve_constraints!(prob, talbeau, f, tsteps)
127+
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
128+
129+
function add_solve_constraints!(prob, tableau, f, tsteps)
122130
A = tableau.A
123131
α = tableau.α
124132
c = tableau.c
125133
model = prob.model
126134
p = prob.p
127135
dt = step(tsteps)
128136

137+
U = model[:U]
129138
if is_explicit(tableau)
130139
K = Any[]
131-
for t in tsteps
140+
for τ in tsteps
132141
for (i, h) in enumerate(c)
133142
ΔU = sum([A[i, j] * K[j] for j in 1:i-1])
134-
Kₙ = f(U + ΔU*dt, p, t + h*dt)
143+
Kₙ = f(U + ΔU*dt, p, τ + h*dt)
135144
push!(K, Kₙ)
136145
end
137-
@constraint(model, U(t) + dot(α, K) == U(t + dt))
146+
@constraint(model, U(τ) + dot(α, K) == U(τ + dt))
138147
empty!(K)
139148
end
140149
else
141150
@variable(model, K[1:length(a)], Infinite(t), start = tsteps[1])
142-
for t in tsteps
143-
ΔUs = A * K(t)
151+
for τ in tsteps
152+
ΔUs = A * K(τ)
144153
for (i, h) in enumerate(c)
145154
ΔU = ΔUs[i]
146-
@constraint(model, K[i](t) == f(U + ΔU*dt, p, t + h*dt))
155+
@constraint(model, K[i](τ) == f(U(τ) + ΔU*dt, p, τ + h*dt))
147156
end
148-
@constraint(model, U(t) + dot(α, K(t)) == U(t + dt))
157+
@constraint(model, U(τ) + dot(α, K(τ)) == U(τ + dt))
149158
end
150159
end
151160
end
152161

153-
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
154-
155162
"""
156163
"""
157164
struct JuMPControlSolution
@@ -167,16 +174,16 @@ name of the solver, e.g. :Tsitouras5 rather than Tsit5.
167174
function solve(prob::JuMPProblem, jump_solver, ode_solver::Symbol)
168175
model = prob.model
169176
f = prob.f
170-
tableau_getter = Symbol(:construct, solver)
177+
tableau_getter = Symbol(:construct, ode_solver)
171178
@eval tableau = $tableau_getter()
172179
ts = prob.tspan[1]:dt:prob.tspan[2]
173180
add_solve_constraints!(model, ts, tableau, f)
174181

175-
set_optimizer(model, solver)
182+
set_optimizer(model, jump_solver)
176183
optimize!(model)
177184

178185
if is_solved_and_feasible(model)
179-
sol = DiffEqBase.build_solution(prob, ode_solver, ts, value(U))
186+
sol = DiffEqBase.build_solution(prob, ode_solver, ts, value.(U))
180187
JuMPControlSolution(model, sol)
181188
end
182189
end

0 commit comments

Comments
 (0)