Skip to content

Commit c3ae7f6

Browse files
refactor: format
1 parent 88bd570 commit c3ae7f6

File tree

5 files changed

+112
-69
lines changed

5 files changed

+112
-69
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ struct MXLinearInterpolation
1717
t::Vector{Float64}
1818
dt::Float64
1919
end
20-
Base.getindex(m::MXLinearInterpolation, i...) = length(i) == length(size(m.u)) ? m.u[i...] : m.u[i..., :]
20+
function Base.getindex(m::MXLinearInterpolation, i...)
21+
length(i) == length(size(m.u)) ? m.u[i...] : m.u[i..., :]
22+
end
2123

2224
mutable struct CasADiModel
2325
model::Opti
@@ -55,7 +57,7 @@ function (M::MXLinearInterpolation)(τ)
5557
(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
5658
colons = ntuple(_ -> (:), length(size(M.u)) - 1)
5759
if i < length(M.t)
58-
M.u[colons..., i] + Δ*(M.u[colons..., i+1] - M.u[colons..., i])
60+
M.u[colons..., i] + Δ * (M.u[colons..., i + 1] - M.u[colons..., i])
5961
else
6062
M.u[colons..., i]
6163
end
@@ -65,7 +67,8 @@ function MTK.CasADiDynamicOptProblem(sys::System, op, tspan;
6567
dt = nothing,
6668
steps = nothing,
6769
guesses = Dict(), kwargs...)
68-
prob, _ = MTK.process_DynamicOptProblem(CasADiDynamicOptProblem, CasADiModel, sys, op, tspan; dt, steps, guesses, kwargs...)
70+
prob, _ = MTK.process_DynamicOptProblem(
71+
CasADiDynamicOptProblem, CasADiModel, sys, op, tspan; dt, steps, guesses, kwargs...)
6972
prob
7073
end
7174

@@ -127,10 +130,10 @@ function MTK.lowered_integral(model::CasADiModel, expr, lo, hi)
127130
for (i, t) in enumerate(model.U.t)
128131
if lo < t < hi
129132
Δt = min(dt, t - lo)
130-
total += (0.5*Δt*(expr[i] + expr[i-1]))
133+
total += (0.5 * Δt * (expr[i] + expr[i - 1]))
131134
elseif t >= hi && (t - dt < hi)
132135
Δt = hi - t + dt
133-
total += (0.5*Δt*(expr[i] + expr[i-1]))
136+
total += (0.5 * Δt * (expr[i] + expr[i - 1]))
134137
end
135138
end
136139
model.tₛ * total
@@ -186,9 +189,13 @@ struct CasADiCollocation <: AbstractCollocation
186189
tableau::DiffEqBase.ODERKTableau
187190
end
188191

189-
MTK.CasADiCollocation(solver, tableau = MTK.constructDefault()) = CasADiCollocation(solver, tableau)
192+
function MTK.CasADiCollocation(solver, tableau = MTK.constructDefault())
193+
CasADiCollocation(solver, tableau)
194+
end
190195

191-
function MTK.prepare_and_optimize!(prob::CasADiDynamicOptProblem, solver::CasADiCollocation; verbose = false, solver_options = Dict(), plugin_options = Dict(), kwargs...)
196+
function MTK.prepare_and_optimize!(
197+
prob::CasADiDynamicOptProblem, solver::CasADiCollocation; verbose = false,
198+
solver_options = Dict(), plugin_options = Dict(), kwargs...)
192199
solver_opti = add_solve_constraints!(prob, solver.tableau)
193200
verbose || (solver_options["print_level"] = 0)
194201
solver!(solver_opti, "$(solver.solver)", plugin_options, solver_options)
@@ -211,7 +218,7 @@ end
211218
function MTK.get_V_values(model::CasADiModel)
212219
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
213220
(nu, nt) = size(model.V.u)
214-
if nu*nt != 0
221+
if nu * nt != 0
215222
V_vals = value_getter(model.solver_opti, model.V.u)
216223
size(V_vals, 2) == 1 && (V_vals = V_vals')
217224
V_vals = [[V_vals[i, j] for i in 1:nu] for j in 1:nt]
@@ -224,9 +231,11 @@ function MTK.get_t_values(model::CasADiModel)
224231
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
225232
ts = value_getter(model.solver_opti, model.tₛ) .* model.U.t
226233
end
227-
MTK.objective_value(model::CasADiModel) = CasADi.pyconvert(Float64, model.solver_opti.py.value(model.solver_opti.py.f))
234+
function MTK.objective_value(model::CasADiModel)
235+
CasADi.pyconvert(Float64, model.solver_opti.py.value(model.solver_opti.py.f))
236+
end
228237

229-
function MTK.successful_solve(m::CasADiModel)
238+
function MTK.successful_solve(m::CasADiModel)
230239
isnothing(m.solver_opti) && return false
231240
retcode = CasADi.return_status(m.solver_opti)
232241
retcode == "Solve_Succeeded" || retcode == "Solved_To_Acceptable_Level"

ext/MTKInfiniteOptExt.jl

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,32 @@ struct InfiniteOptDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
4848
end
4949

5050
MTK.generate_internal_model(m::Type{InfiniteOptModel}) = InfiniteModel()
51-
MTK.generate_time_variable!(m::InfiniteModel, tspan, tsteps) = @infinite_parameter(m, t in [tspan[1], tspan[2]], num_supports = length(tsteps))
52-
MTK.generate_state_variable!(m::InfiniteModel, u0::Vector, ns, ts) = @variable(m, U[i = 1:ns], Infinite(m[:t]), start=u0[i])
53-
MTK.generate_input_variable!(m::InfiniteModel, c0, nc, ts) = @variable(m, V[i = 1:nc], Infinite(m[:t]), start=c0[i])
51+
function MTK.generate_time_variable!(m::InfiniteModel, tspan, tsteps)
52+
@infinite_parameter(m, t in [tspan[1], tspan[2]], num_supports=length(tsteps))
53+
end
54+
function MTK.generate_state_variable!(m::InfiniteModel, u0::Vector, ns, ts)
55+
@variable(m, U[i = 1:ns], Infinite(m[:t]), start=u0[i])
56+
end
57+
function MTK.generate_input_variable!(m::InfiniteModel, c0, nc, ts)
58+
@variable(m, V[i = 1:nc], Infinite(m[:t]), start=c0[i])
59+
end
5460

5561
function MTK.generate_timescale!(m::InfiniteModel, guess, is_free_t)
56-
@variable(m, tₛ 0, start = guess)
62+
@variable(m, tₛ0, start=guess)
5763
if !is_free_t
58-
fix(tₛ, 1, force=true)
64+
fix(tₛ, 1, force = true)
5965
set_start_value(tₛ, 1)
6066
end
6167
tₛ
6268
end
6369

64-
function MTK.add_constraint!(m::InfiniteOptModel, expr::Union{Equation, Inequality})
70+
function MTK.add_constraint!(m::InfiniteOptModel, expr::Union{Equation, Inequality})
6571
if expr isa Equation
66-
@constraint(m.model, expr.lhs - expr.rhs == 0)
72+
@constraint(m.model, expr.lhs - expr.rhs==0)
6773
elseif expr.relational_op === Symbolics.geq
68-
@constraint(m.model, expr.lhs - expr.rhs 0)
74+
@constraint(m.model, expr.lhs - expr.rhs0)
6975
else
70-
@constraint(m.model, expr.lhs - expr.rhs 0)
76+
@constraint(m.model, expr.lhs - expr.rhs0)
7177
end
7278
end
7379
MTK.set_objective!(m::InfiniteOptModel, expr) = @objective(m.model, Min, expr)
@@ -76,20 +82,25 @@ function MTK.JuMPDynamicOptProblem(sys::System, op, tspan;
7682
dt = nothing,
7783
steps = nothing,
7884
guesses = Dict(), kwargs...)
79-
prob, _ = MTK.process_DynamicOptProblem(JuMPDynamicOptProblem, InfiniteOptModel, sys, op, tspan; dt, steps, guesses, kwargs...)
85+
prob, _ = MTK.process_DynamicOptProblem(JuMPDynamicOptProblem, InfiniteOptModel, sys,
86+
op, tspan; dt, steps, guesses, kwargs...)
8087
prob
8188
end
8289

8390
function MTK.InfiniteOptDynamicOptProblem(sys::System, op, tspan;
8491
dt = nothing,
8592
steps = nothing,
8693
guesses = Dict(), kwargs...)
87-
prob, pmap = MTK.process_DynamicOptProblem(InfiniteOptDynamicOptProblem, InfiniteOptModel, sys, op, tspan; dt, steps, guesses, kwargs...)
94+
prob, pmap = MTK.process_DynamicOptProblem(
95+
InfiniteOptDynamicOptProblem, InfiniteOptModel,
96+
sys, op, tspan; dt, steps, guesses, kwargs...)
8897
MTK.add_equational_constraints!(prob.wrapped_model, sys, pmap, tspan)
8998
prob
9099
end
91100

92-
MTK.lowered_integral(model::InfiniteOptModel, expr, lo, hi) = model.tₛ * InfiniteOpt.(expr, model.model[:t], lo, hi)
101+
function MTK.lowered_integral(model::InfiniteOptModel, expr, lo, hi)
102+
model.tₛ * InfiniteOpt.(expr, model.model[:t], lo, hi)
103+
end
93104
MTK.lowered_derivative(model::InfiniteOptModel, i) = (model.U[i], model.model[:t])
94105

95106
function MTK.process_integral_bounds(model::InfiniteOptModel, integral_span, tspan)
@@ -125,7 +136,7 @@ function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
125136
nᵥ = length(V)
126137
if MTK.is_explicit(tableau)
127138
K = Any[]
128-
for τ in tsteps[1:end-1]
139+
for τ in tsteps[1:(end - 1)]
129140
for (i, h) in enumerate(c)
130141
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
131142
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
@@ -142,14 +153,15 @@ function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
142153
K = @variable(model, K[1:length(α), 1:nᵤ], Infinite(model[:t]))
143154
ΔUs = A * K
144155
ΔU_tot = dt * (K' * α)
145-
for τ in tsteps[1:end-1]
156+
for τ in tsteps[1:(end - 1)]
146157
for (i, h) in enumerate(c)
147158
ΔU = @view ΔUs[i, :]
148159
Uₙ = U + ΔU * dt
149160
@constraint(model, [j = 1:nᵤ], K[i, j]==(tₛ * f(Uₙ, V, p, τ + h * dt)[j]),
150161
DomainRestrictions(t => τ), base_name="solve_K$i()")
151162
end
152-
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](min+ dt, tsteps[end])),
163+
@constraint(model,
164+
[n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](min+ dt, tsteps[end])),
153165
DomainRestrictions(t => τ), base_name="solve_U()")
154166
end
155167
end
@@ -159,15 +171,21 @@ struct JuMPCollocation <: AbstractCollocation
159171
solver::Any
160172
tableau::DiffEqBase.ODERKTableau
161173
end
162-
MTK.JuMPCollocation(solver, tableau = MTK.constructDefault()) = JuMPCollocation(solver, tableau)
174+
function MTK.JuMPCollocation(solver, tableau = MTK.constructDefault())
175+
JuMPCollocation(solver, tableau)
176+
end
163177

164178
struct InfiniteOptCollocation <: AbstractCollocation
165179
solver::Any
166180
derivative_method::InfiniteOpt.AbstractDerivativeMethod
167181
end
168-
MTK.InfiniteOptCollocation(solver, derivative_method = InfiniteOpt.FiniteDifference(InfiniteOpt.Backward())) = InfiniteOptCollocation(solver, derivative_method)
182+
function MTK.InfiniteOptCollocation(
183+
solver, derivative_method = InfiniteOpt.FiniteDifference(InfiniteOpt.Backward()))
184+
InfiniteOptCollocation(solver, derivative_method)
185+
end
169186

170-
function MTK.prepare_and_optimize!(prob::JuMPDynamicOptProblem, solver::JuMPCollocation; verbose = false, kwargs...)
187+
function MTK.prepare_and_optimize!(
188+
prob::JuMPDynamicOptProblem, solver::JuMPCollocation; verbose = false, kwargs...)
171189
model = prob.wrapped_model.model
172190
verbose || set_silent(model)
173191
# Unregister current solver constraints
@@ -190,7 +208,8 @@ function MTK.prepare_and_optimize!(prob::JuMPDynamicOptProblem, solver::JuMPColl
190208
model
191209
end
192210

193-
function MTK.prepare_and_optimize!(prob::InfiniteOptDynamicOptProblem, solver::InfiniteOptCollocation; verbose = false, kwargs...)
211+
function MTK.prepare_and_optimize!(prob::InfiniteOptDynamicOptProblem,
212+
solver::InfiniteOptCollocation; verbose = false, kwargs...)
194213
model = prob.wrapped_model.model
195214
verbose || set_silent(model)
196215
set_derivative_method(model[:t], solver.derivative_method)
@@ -223,8 +242,8 @@ function MTK.successful_solve(model::InfiniteModel)
223242
error("Model not solvable; please report this to github.com/SciML/ModelingToolkit.jl with a MWE.")
224243

225244
pstatus === FEASIBLE_POINT &&
226-
(tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL ||
227-
tstatus === ALMOST_LOCALLY_SOLVED)
245+
(tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL ||
246+
tstatus === ALMOST_LOCALLY_SOLVED)
228247
end
229248

230249
import InfiniteOpt: JuMP, GeneralVariableRef

ext/MTKPyomoDynamicOptExt.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ function MTK.add_constraint!(pmodel::PyomoDynamicOptModel, cons; n_idxs = 1)
108108
else
109109
cons.lhs - cons.rhs 0
110110
end
111-
expr = Symbolics.substitute(Symbolics.unwrap(expr), SPECIAL_FUNCTIONS_DICT, fold = false)
111+
expr = Symbolics.substitute(
112+
Symbolics.unwrap(expr), SPECIAL_FUNCTIONS_DICT, fold = false)
112113

113114
cons_sym = Symbol("cons", hash(cons))
114115
if occursin(Symbolics.unwrap(t_sym), expr)
@@ -141,17 +142,17 @@ end
141142
function MTK.lowered_integral(m::PyomoDynamicOptModel, arg, lo, hi)
142143
@unpack model, model_sym, t_sym, dummy_sym = m
143144
total = 0
144-
dt = Pyomo.pyconvert(Float64, (model.t.at(-1) - model.t.at(1))/(model.steps - 1))
145+
dt = Pyomo.pyconvert(Float64, (model.t.at(-1) - model.t.at(1)) / (model.steps - 1))
145146
f = Symbolics.build_function(arg, model_sym, t_sym, expression = false)
146147
for (i, t) in enumerate(model.t)
147148
if Bool(lo < t) && Bool(t < hi)
148-
t_p = model.t.at(i-1)
149+
t_p = model.t.at(i - 1)
149150
Δt = min(t - lo, t - t_p)
150-
total += 0.5*Δt*(f(model, t) + f(model, t_p))
151+
total += 0.5 * Δt * (f(model, t) + f(model, t_p))
151152
elseif Bool(t >= hi) && Bool(t - dt < hi)
152-
t_p = model.t.at(i-1)
153+
t_p = model.t.at(i - 1)
153154
Δt = hi - t + dt
154-
total += 0.5*Δt*(f(model, t) + f(model, t_p))
155+
total += 0.5 * Δt * (f(model, t) + f(model, t_p))
155156
end
156157
end
157158
PyomoVar(model.tₛ * total)

src/systems/optimal_control_interface.jl

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function PyomoCollocation end
9393

9494
function warn_overdetermined(sys, op)
9595
cstrs = constraints(sys)
96-
init_conds = filter(x -> value(x) Set(unknowns(sys)), [k for (k,v) in op])
96+
init_conds = filter(x -> value(x) Set(unknowns(sys)), [k for (k, v) in op])
9797
if !isempty(cstrs)
9898
(length(cstrs) + length(init_conds) > length(unknowns(sys))) &&
9999
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by op) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
@@ -227,19 +227,19 @@ end
227227
##########################
228228
### MODEL CONSTRUCTION ###
229229
##########################
230-
function process_DynamicOptProblem(prob_type::Type{<:AbstractDynamicOptProblem}, model_type, sys::System, op, tspan;
231-
dt = nothing,
232-
steps = nothing,
233-
guesses = Dict(), kwargs...)
234-
230+
function process_DynamicOptProblem(
231+
prob_type::Type{<:AbstractDynamicOptProblem}, model_type, sys::System, op, tspan;
232+
dt = nothing,
233+
steps = nothing,
234+
guesses = Dict(), kwargs...)
235235
warn_overdetermined(sys, op)
236236
ctrls = unbound_inputs(sys)
237237
states = unknowns(sys)
238238

239239
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
240240
op = Dict([default_toterm(value(k)) => v for (k, v) in op])
241241
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
242-
[stidxmap[default_toterm(k)] for (k, v) in op if haskey(stidxmap, k)]
242+
[stidxmap[default_toterm(k)] for (k, v) in op if haskey(stidxmap, k)]
243243

244244
_op = has_alg_eqs(sys) ? op : merge(Dict(op), Dict(guesses))
245245
f, u0, p = process_SciMLProblem(ODEInputFunction, sys, _op;
@@ -302,7 +302,7 @@ function set_variable_bounds!(m, sys, pmap, tf)
302302
end
303303
end
304304

305-
is_free_final(model) = model.is_free_final
305+
is_free_final(model) = model.is_free_final
306306

307307
function add_cost_function!(model, sys, tspan, pmap)
308308
jcosts = cost(sys)
@@ -335,14 +335,15 @@ function substitute_integral(model, expr, tspan)
335335
Symbolics.substitute(expr, intmap)
336336
end
337337

338-
function process_integral_bounds(model, integral_span, tspan)
338+
function process_integral_bounds(model, integral_span, tspan)
339339
if is_free_final(model) && isequal(integral_span, tspan)
340340
integral_span = (0, 1)
341341
elseif is_free_final(model)
342342
error("Free final time problems cannot handle partial timespans.")
343343
else
344344
(lo, hi) = integral_span
345-
(lo < tspan[1] || hi > tspan[2]) && error("Integral bounds are beyond the timespan.")
345+
(lo < tspan[1] || hi > tspan[2]) &&
346+
error("Integral bounds are beyond the timespan.")
346347
integral_span
347348
end
348349
end
@@ -353,12 +354,14 @@ function substitute_model_vars(model, sys, exprs, tspan)
353354
c_ops = [operation(unwrap(ct)) for ct in unbound_inputs(sys)]
354355
t = get_iv(sys)
355356

356-
exprs = map(c -> Symbolics.fast_substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs)
357+
exprs = map(
358+
c -> Symbolics.fast_substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs)
357359

358360
(ti, tf) = tspan
359361
if symbolic_type(tf) === ScalarSymbolic()
360362
_tf = model.tₛ + ti
361-
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs)
363+
exprs = map(
364+
c -> Symbolics.fast_substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs)
362365
exprs = map(c -> Symbolics.fast_substitute(c, Dict(tf => _tf)), exprs)
363366
end
364367
exprs = map(c -> Symbolics.fast_substitute(c, fixed_t_map(model, x_ops, c_ops)), exprs)
@@ -392,7 +395,8 @@ function fixed_t_map end
392395
function add_user_constraints!(model, sys, tspan, pmap)
393396
jconstraints = get_constraints(sys)
394397
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
395-
cons_dvs, cons_ps = process_constraint_system(jconstraints, Set(unknowns(sys)), parameters(sys), get_iv(sys); validate = false)
398+
cons_dvs, cons_ps = process_constraint_system(
399+
jconstraints, Set(unknowns(sys)), parameters(sys), get_iv(sys); validate = false)
396400

397401
is_free_final(model) && check_constraint_vars(cons_dvs)
398402

@@ -421,12 +425,13 @@ function add_equational_constraints!(model, sys, pmap, tspan)
421425
end
422426

423427
function set_objective! end
424-
objective_value(sol::DynamicOptSolution) = objective_value(sol.model)
428+
objective_value(sol::DynamicOptSolution) = objective_value(sol.model)
425429

426430
function substitute_differentials(model, sys, eqs)
427431
t = get_iv(sys)
428432
D = Differential(t)
429-
diffsubmap = Dict([D(lowered_var(model, :U, i, t)) => lowered_derivative(model, i) for i in 1:length(unknowns(sys))])
433+
diffsubmap = Dict([D(lowered_var(model, :U, i, t)) => lowered_derivative(model, i)
434+
for i in 1:length(unknowns(sys))])
430435
eqs = map(c -> Symbolics.substitute(c, diffsubmap), eqs)
431436
end
432437

@@ -466,9 +471,10 @@ function successful_solve end
466471
467472
- kwargs are used for other options. For example, the `plugin_options` and `solver_options` will propagated to the Opti object in CasADi.
468473
"""
469-
function DiffEqBase.solve(prob::AbstractDynamicOptProblem, solver::AbstractCollocation; verbose = false, kwargs...)
474+
function DiffEqBase.solve(prob::AbstractDynamicOptProblem,
475+
solver::AbstractCollocation; verbose = false, kwargs...)
470476
solved_model = prepare_and_optimize!(prob, solver; verbose, kwargs...)
471-
477+
472478
ts = get_t_values(solved_model)
473479
Us = get_U_values(solved_model)
474480
Vs = get_V_values(solved_model)

0 commit comments

Comments
 (0)