@@ -33,7 +33,7 @@ The constraints are:
33
33
- The set of user constraints passed to the ODESystem via `constraints`
34
34
- The solver constraints that encode the time-stepping used by the solver
35
35
"""
36
- function MTK. JuMPControlProblem (sys:: ODESystem , u0map, tspan, pmap; dt = error (" dt must be provided for JuMPControlProblem." ), kwargs... )
36
+ function MTK. JuMPControlProblem (sys:: ODESystem , u0map, tspan, pmap; dt = error (" dt must be provided for JuMPControlProblem." ), guesses = Dict (), kwargs... )
37
37
ts = tspan[1 ]
38
38
te = tspan[2 ]
39
39
steps = ts: dt: te
@@ -42,11 +42,12 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("
42
42
constraintsys = MTK. get_constraintsystem (sys)
43
43
44
44
if ! isnothing (constraintsys)
45
- (length (constraints (constraintsys)) + length (u0map) > length (sts )) &&
45
+ (length (constraints (constraintsys)) + length (u0map) > length (states )) &&
46
46
@warn " The JuMPControlProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
47
47
end
48
48
49
- f, u0, p = MTK. process_SciMLProblem (ODEFunction, sys, u0map, pmap;
49
+ _u0map = has_alg_eqs (sys) ? u0map : merge (Dict (u0map), Dict (guesses))
50
+ f, u0, p = MTK. process_SciMLProblem (ODEFunction, sys, _u0map, pmap;
50
51
t = tspan != = nothing ? tspan[1 ] : tspan, kwargs... )
51
52
52
53
model = InfiniteModel ()
67
68
function add_jump_cost_function! (model, sys)
68
69
jcosts = MTK. get_costs (sys)
69
70
consolidate = MTK. get_consolidate (sys)
70
- if isnothing (consolidate )
71
+ if isnothing (jcosts )
71
72
@objective (model, Min, 0 )
72
73
return
73
74
end
@@ -81,55 +82,52 @@ function add_jump_cost_function!(model, sys)
81
82
t = only (arguments (st))
82
83
idx = stidxmap[x (iv)]
83
84
subval = isequal (t, iv) ? model[:U ][idx] : model[:U ][idx](t)
84
- jcosts = Symbolics. substitute (jcosts , Dict (x (t) => subval))
85
+ jcosts = map (c -> Symbolics. substitute (c , Dict (x (t) => subval)), jcosts )
85
86
end
86
87
87
88
for ct in controls (sys)
88
89
p = operation (ct)
89
90
t = only (arguments (ct))
90
91
idx = cidxmap[p (iv)]
91
92
subval = isequal (t, iv) ? model[:V ][idx] : model[:V ][idx](t)
92
- jcosts = Symbolics. substitute (jcosts , Dict (x (t) => subval))
93
+ jcosts = map (c -> Symbolics. substitute (c , Dict (p (t) => subval)), jcosts )
93
94
end
94
95
95
96
@objective (model, Min, consolidate (jcosts))
96
97
end
97
98
98
99
function add_user_constraints! (model, sys)
99
- jconstraints = if ! (csys = MTK. get_constraintsystem (sys) isa Nothing)
100
- MTK. get_constraints (csys)
101
- else
102
- nothing
103
- end
104
- isnothing (jconstraints) && return nothing
100
+ conssys = MTK. get_constraintsystem (sys)
101
+ jconstraints = isnothing (conssys) ? nothing : MTK. get_constraints (conssys)
102
+ (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
105
103
106
104
iv = MTK. get_iv (sys)
107
105
stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
108
106
cidxmap = Dict ([v => i for (i, v) in enumerate (controls (sys))])
109
107
110
- for st in unknowns (sys )
108
+ for st in unknowns (conssys )
111
109
x = operation (st)
112
110
t = only (arguments (st))
113
111
idx = stidxmap[x (iv)]
114
112
subval = isequal (t, iv) ? model[:U ][idx] : model[:U ][idx](t)
115
- jconstraints = Symbolics. substitute (jconstraints , Dict (x (t) => subval))
113
+ jconstraints = map (c -> Symbolics. substitute (c , Dict (x (t) => subval)), jconstraints )
116
114
end
117
115
118
116
for ct in controls (sys)
119
117
p = operation (ct)
120
118
t = only (arguments (ct))
121
119
idx = cidxmap[p (iv)]
122
120
subval = isequal (t, iv) ? model[:V ][idx] : model[:V ][idx](t)
123
- jconstraints = Symbolics. substitute (jconstraints, Dict (p (t) => subval))
121
+ jconstraints = map (c -> Symbolics. substitute (jconstraints, Dict (p (t) => subval)), jconstriants )
124
122
end
125
123
126
124
for (i, cons) in enumerate (jconstraints)
127
125
if cons isa Equation
128
- @constraint (model, user[i], cons. lhs - cons. rhs == 0 )
126
+ @constraint (model, cons. lhs - cons. rhs == 0 , base_name = " user[ $i ] " )
129
127
elseif cons. relational_op === Symbolics. geq
130
- @constraint (model, user[i], cons. lhs - cons. rhs ≥ 0 )
128
+ @constraint (model, cons. lhs - cons. rhs ≥ 0 , base_name = " user[ $i ] " )
131
129
else
132
- @constraint (model, user[i], cons. lhs - cons. rhs ≤ 0 )
130
+ @constraint (model, cons. lhs - cons. rhs ≤ 0 , base_name = " user[ $i ] " )
133
131
end
134
132
end
135
133
end
189
187
struct JuMPControlSolution
190
188
model:: InfiniteModel
191
189
sol:: ODESolution
190
+ input_sol:: Union{Nothing, ODESolution}
192
191
end
193
192
194
193
"""
@@ -213,7 +212,6 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
213
212
end
214
213
end
215
214
for var in all_variables (model)
216
- @show JuMP. name (var)
217
215
if occursin (" K" , JuMP. name (var))
218
216
unregister (model, Symbol (JuMP. name (var)))
219
217
delete (model, var)
@@ -232,10 +230,19 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
232
230
U_vals = [[U_vals[i][j] for i in 1 : length (U_vals)] for j in 1 : length (ts)]
233
231
sol = DiffEqBase. build_solution (prob, ode_solver, ts, U_vals)
234
232
233
+ input_sol = nothing
234
+ if ! isempty (model[:V ])
235
+ V_vals = value .(model[:V ])
236
+ V_vals = [[V_vals[i][j] for i in 1 : length (V_vals)] for j in 1 : length (ts)]
237
+ input_sol = DiffEqBase. build_solution (prob, ode_solver, ts, V_vals)
238
+ end
239
+
235
240
if ! (pstatus === FEASIBLE_POINT && (tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL || tstatus === ALMOST_LOCALLY_SOLVED))
236
241
sol = SciMLBase. solution_new_retcode (sol, SciMLBase. ReturnCode. ConvergenceFailure)
242
+ ! isnothing (input_sol) && (input_sol = SciMLBase. solution_new_retcode (input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
237
243
end
238
- JuMPControlSolution (model, sol)
244
+
245
+ JuMPControlSolution (model, sol, input_sol)
239
246
end
240
247
241
248
end
0 commit comments