Skip to content

Commit cb42511

Browse files
committed
feat: free final time problems
1 parent 7f86e85 commit cb42511

File tree

7 files changed

+165
-78
lines changed

7 files changed

+165
-78
lines changed

ext/MTKJuMPControlExt.jl

+84-43
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ModelingToolkit
33
using JuMP, InfiniteOpt
44
using DiffEqDevTools, DiffEqBase
55
using LinearAlgebra
6+
using StaticArrays
67
const MTK = ModelingToolkit
78

89
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <:
@@ -14,7 +15,7 @@ struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <:
1415
model::InfiniteModel
1516
kwargs::K
1617

17-
function JuMPControlProblem(f, u0, tspan, p, model; kwargs...)
18+
function JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
1819
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f, 5),
1920
typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
2021
end
@@ -51,14 +52,18 @@ The constraints are:
5152
- The solver constraints that encode the time-stepping used by the solver
5253
"""
5354
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
54-
dt = error("dt must be provided for JuMPControlProblem."),
55+
dt = nothing,
56+
steps = nothing,
5557
guesses = Dict(), kwargs...)
5658
MTK.warn_overdetermined(sys, u0map)
5759
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
5860
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
5961
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
6062

61-
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, pmap, u0)
63+
pmap = MTK.todict(pmap)
64+
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
65+
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
66+
6267
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
6368
end
6469

@@ -73,82 +78,115 @@ Related to `JuMPControlProblem`, but directly adds the differential equations
7378
of the system as derivative constraints, rather than using a solver tableau.
7479
"""
7580
function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
76-
dt = error("dt must be provided for InfiniteOptControlProblem."),
81+
dt = nothing,
82+
steps = nothing,
7783
guesses = Dict(), kwargs...)
7884
MTK.warn_overdetermined(sys, u0map)
7985
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
8086
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
8187
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
8288

83-
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, pmap, u0)
84-
add_infopt_solve_constraints!(model, sys, pmap)
89+
pmap = MTK.todict(pmap)
90+
steps, is_free_t = MTK.process_tspan(tspan, dt, steps)
91+
model = init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t)
92+
93+
add_infopt_solve_constraints!(model, sys, pmap; is_free_t)
8594
InfiniteOptControlProblem(f, u0, tspan, p, model, kwargs...)
8695
end
8796

88-
function init_model(sys, tsteps, u0map, pmap, u0)
97+
# Initialize InfiniteOpt model.
98+
function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
8999
ctrls = MTK.unbound_inputs(sys)
90100
states = unknowns(sys)
91101
model = InfiniteModel()
92102

93-
@infinite_parameter(model, t in [tsteps[1], tsteps[end]], num_supports=length(tsteps))
94-
@variable(model, U[i = 1:length(states)], Infinite(t))
95-
@variable(model, V[1:length(ctrls)], Infinite(t))
103+
if is_free_t
104+
(ts_sym, te_sym) = tspan
105+
@variable(model, tf, start = pmap[te_sym])
106+
hasbounds(te_sym) && begin
107+
lo, hi = getbounds(te_sym)
108+
set_lower_bound(tf, lo)
109+
set_upper_bound(tf, hi)
110+
end
111+
pmap[ts_sym] = 0
112+
pmap[te_sym] = 1
113+
tspan = (0, 1)
114+
end
115+
116+
@infinite_parameter(model, t in [tspan[1], tspan[2]], num_supports = steps)
117+
@variable(model, U[i = 1:length(states)], Infinite(t), start = u0[i])
118+
c0 = [pmap[c] for c in ctrls]
119+
@variable(model, V[i = 1:length(ctrls)], Infinite(t), start = c0[i])
96120

97-
set_bounds!(model, sys)
98-
add_jump_cost_function!(model, sys, (tsteps[1], tsteps[2]), pmap)
99-
add_user_constraints!(model, sys, pmap)
121+
set_jump_bounds!(model, sys, pmap)
122+
add_jump_cost_function!(model, sys, (tspan[1], tspan[2]), pmap; is_free_t)
123+
add_user_constraints!(model, sys, pmap; is_free_t)
100124

101125
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
102126
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
103127
[stidxmap[k] for (k, v) in u0map]
104-
add_initial_constraints!(model, u0, u0_idxs, tsteps[1])
128+
add_initial_constraints!(model, u0, u0_idxs, tspan[1])
105129
return model
106130
end
107131

108-
function set_bounds!(model, sys)
132+
function set_jump_bounds!(model, sys, pmap)
109133
U = model[:U]
110134
for (i, u) in enumerate(unknowns(sys))
111-
lo, hi = MTK.getbounds(u)
112-
set_lower_bound(U[i], lo)
113-
set_upper_bound(U[i], hi)
135+
if MTK.hasbounds(u)
136+
lo, hi = MTK.getbounds(u)
137+
set_lower_bound(U[i], Symbolics.fixpoint_sub(lo, pmap))
138+
set_upper_bound(U[i], Symbolics.fixpoint_sub(hi, pmap))
139+
end
114140
end
115141

116142
V = model[:V]
117143
for (i, v) in enumerate(MTK.unbound_inputs(sys))
118-
lo, hi = MTK.getbounds(v)
119-
set_lower_bound(V[i], lo)
120-
set_upper_bound(V[i], hi)
144+
if MTK.hasbounds(v)
145+
lo, hi = MTK.getbounds(v)
146+
set_lower_bound(V[i], Symbolics.fixpoint_sub(lo, pmap))
147+
set_upper_bound(V[i], Symbolics.fixpoint_sub(hi, pmap))
148+
end
121149
end
122150
end
123151

124-
function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap)
152+
function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free_t = false)
125153
jcosts = MTK.get_costs(sys)
126154
consolidate = MTK.get_consolidate(sys)
127155
if isnothing(jcosts) || isempty(jcosts)
128156
@objective(model, Min, 0)
129157
return
130158
end
131159
jcosts = substitute_jump_vars(model, sys, pmap, jcosts)
160+
tₛ = is_free_t ? model[:tf] : 1
132161

133162
# Substitute integral
134163
iv = MTK.get_iv(sys)
135-
jcosts = map(c -> Symbolics.substitute(c, ∫ => Symbolics.Integral(iv in tspan)), jcosts)
164+
jcosts = map(c -> Symbolics.substitute(c, MTK.() => Symbolics.Integral(iv in tspan)), jcosts)
165+
136166
intmap = Dict()
137-
138167
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
168+
op = MTK.operation(int)
139169
arg = only(arguments(MTK.value(int)))
140-
lower_bound, upper_bound = (int.domain.domain.left, int.domain.domain.right)
141-
intmap[int] = InfiniteOpt.(arg, iv; lower_bound, upper_bound)
170+
lo, hi = (op.domain.domain.left, op.domain.domain.right)
171+
intmap[int] = tₛ * InfiniteOpt.(arg, model[:t], lo, hi)
142172
end
143173
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
144174
@objective(model, Min, consolidate(jcosts))
145175
end
146176

147-
function add_user_constraints!(model::InfiniteModel, sys, pmap)
177+
function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = false)
148178
conssys = MTK.get_constraintsystem(sys)
149179
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
150180
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
151181

182+
if is_free_t
183+
for u in MTK.get_unknowns(conssys)
184+
x = MTK.operation(u)
185+
t = only(arguments(u))
186+
MTK.symbolic_type(t) === NotSymbolic() && error("Provided specific time constraint in a free final time problem. This is not supported by the JuMP/InfiniteOpt collocation solvers. The offending variable is $u.")
187+
end
188+
end
189+
152190
jconstraints = substitute_jump_vars(model, sys, pmap, jconstraints)
153191
for (i, cons) in enumerate(jconstraints)
154192
if cons isa Equation
@@ -188,23 +226,24 @@ end
188226

189227
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
190228

191-
function add_infopt_solve_constraints!(model::InfiniteModel, sys, pmap)
229+
function add_infopt_solve_constraints!(model::InfiniteModel, sys, pmap; is_free_t = false)
192230
# Differential equations
193231
U = model[:U]
194232
t = model[:t]
195233
D = Differential(MTK.get_iv(sys))
196234
diffsubmap = Dict([D(U[i]) => (U[i], t) for i in 1:length(U)])
235+
tₛ = is_free_t ? model[:tf] : 1
197236

198237
diff_eqs = substitute_jump_vars(model, sys, pmap, diff_equations(sys))
199238
diff_eqs = map(e -> Symbolics.substitute(e, diffsubmap), diff_eqs)
200-
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs==diff_eqs[i].rhs)
239+
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs == tₛ * diff_eqs[i].rhs)
201240

202241
# Algebraic equations
203242
alg_eqs = substitute_jump_vars(model, sys, pmap, alg_equations(sys))
204-
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs==alg_eqs[i].rhs)
243+
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs == alg_eqs[i].rhs)
205244
end
206245

207-
function add_jump_solve_constraints!(prob, tableau)
246+
function add_jump_solve_constraints!(prob, tableau; is_free_t = false)
208247
A = tableau.A
209248
α = tableau.α
210249
c = tableau.c
@@ -214,6 +253,7 @@ function add_jump_solve_constraints!(prob, tableau)
214253
t = model[:t]
215254
tsteps = supports(model[:t])
216255
pop!(tsteps)
256+
tₛ = is_free_t ? model[:tf] : 1
217257
dt = tsteps[2] - tsteps[1]
218258

219259
U = model[:U]
@@ -227,7 +267,7 @@ function add_jump_solve_constraints!(prob, tableau)
227267
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
228268
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
229269
Vₙ = [V[i](τ) for i in 1:nᵥ]
230-
Kₙ = f(Uₙ, Vₙ, p, τ + h * dt)
270+
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
231271
push!(K, Kₙ)
232272
end
233273
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
@@ -237,17 +277,17 @@ function add_jump_solve_constraints!(prob, tableau)
237277
end
238278
else
239279
@variable(model, K[1:length(α), 1:nᵤ], Infinite(t), start=tsteps[1])
280+
ΔUs = A * K
281+
ΔU_tot = dt * (K' * α)
240282
for τ in tsteps
241-
ΔUs = A * K
242283
for (i, h) in enumerate(c)
243-
ΔU = ΔUs[i, :]
244-
Uₙ = [U[j] + ΔU[j] * dt for j in 1:nᵤ]
245-
@constraint(model, [j in 1:nᵤ], K[i, j]==f(Uₙ, V, p, τ + h * dt)[j],
246-
DomainRestrictions(t => τ), base_name="solve_K()")
284+
ΔU = @view ΔUs[i, :]
285+
Uₙ = U + ΔU * dt
286+
@constraint(model, [j = 1:nᵤ], K[i, j](τ) == tₛ * f(Uₙ, V, p, τ + h * dt)[j],
287+
DomainRestrictions(t => τ + h*dt), base_name="solve_K()")
247288
end
248-
ΔU = dt * sum([α[i] * K[i, :] for i in 1:length(α)])
249-
@constraint(model, [n = 1:nᵤ], U[n] + ΔU[n]==U[n](τ + dt),
250-
DomainRestrictions(t => τ), base_name="solve_U()")
289+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n] == U[n](τ + dt),
290+
DomainRestrictions(t => τ), base_name="solve_U()")
251291
end
252292
end
253293
end
@@ -281,7 +321,7 @@ function DiffEqBase.solve(
281321
delete(model, var)
282322
end
283323
end
284-
add_jump_solve_constraints!(prob, tableau)
324+
add_jump_solve_constraints!(prob, tableau; is_free_t = haskey(model, :tf))
285325
_solve(prob, jump_solver, ode_solver)
286326
end
287327

@@ -304,9 +344,10 @@ function _solve(prob::AbstractOptimalControlProblem, jump_solver, solver)
304344
tstatus = termination_status(model)
305345
pstatus = primal_status(model)
306346
!has_values(model) &&
307-
error("Model not solvable; please report this to github.com/SciML/ModelingToolkit.jl.")
347+
error("Model not solvable; please report this to github.com/SciML/ModelingToolkit.jl with a MWE.")
308348

309-
ts = supports(model[:t])
349+
tf = haskey(model, :tf) ? value(model[:tf]) : 1
350+
ts = tf * supports(model[:t])
310351
U_vals = value.(model[:U])
311352
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:length(ts)]
312353
sol = DiffEqBase.build_solution(prob, solver, ts, U_vals)

src/ModelingToolkit.jl

+1
Original file line numberDiff line numberDiff line change
@@ -351,5 +351,6 @@ include("systems/optimal_control_interface.jl")
351351
export AbstractOptimalControlProblem, JuMPControlProblem, InfiniteOptControlProblem,
352352
PyomoControlProblem, CasADiControlProblem
353353
export OptimalControlSolution
354+
export
354355

355356
end # module

src/inputoutput.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
250250
args = (ddvs, args...)
251251
end
252252
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
253-
p_end = length(p) + 2 + implicit_dae)
253+
p_end = length(p) + 2 + implicit_dae, kwargs...)
254254
f = eval_or_rgf.(f; eval_expression, eval_module)
255255
f = GeneratedFunctionWrapper{(3, length(args) - length(p) + 1, is_split(sys))}(f...)
256256
ps = setdiff(parameters(sys), inputs, disturbance_inputs)

src/systems/diffeqs/odesystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
819819
arg isa AbstractFloat ||
820820
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
821821

822-
isparameter(arg) && push!(auxps, arg)
822+
(isparameter(arg) && !isequal(arg, iv)) && push!(auxps, arg)
823823
else
824824
var sts &&
825825
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."

src/systems/optimal_control_interface.jl

+50-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ struct OptimalControlSolution
77
input_sol::Union{Nothing, ODESolution}
88
end
99

10+
function Base.show(io::IO, sol::OptimalControlSolution)
11+
println("retcode: ", sol.sol.retcode, "\n")
12+
13+
println("Optimal control solution for following model:\n")
14+
show(sol.model)
15+
16+
print("\n\nPlease query the model using sol.model, the solution trajectory for the system using sol.sol, or the solution trajectory for the controllers using sol.input_sol.")
17+
end
18+
1019
function JuMPControlProblem end
1120
function InfiniteOptControlProblem end
1221
function CasADiControlProblem end
@@ -44,7 +53,7 @@ function SciMLBase.ControlFunction{iip, specialize}(sys::ODESystem,
4453
cse = true,
4554
kwargs...) where {iip, specialize}
4655

47-
(f), _, _ = generate_control_function(sys, inputs, disturbance_inputs; eval_expression = true, eval_module, cse, kwargs...)
56+
(f), _, _ = generate_control_function(sys, inputs, disturbance_inputs; eval_module, cse, kwargs...)
4857

4958
if tgrad
5059
tgrad_gen = generate_tgrad(sys, dvs, ps;
@@ -134,7 +143,7 @@ function SciMLBase.ControlFunction{false}(sys::AbstractODESystem, args...;
134143
end
135144

136145
"""
137-
IntegralNorm. When applied to an expression in a cost
146+
Integral operator. When applied to an expression in a cost
138147
function, assumes that the integration variable is the
139148
iv of the system, and assumes that the bounds are the
140149
tspan.
@@ -143,17 +152,46 @@ Equivalent to Integral(t in tspan) in Symbolics.
143152
struct<: Symbolics.Operator end
144153
(x) = ()(x)
145154
Base.show(io::IO, x::∫) = print(io, "")
155+
Base.nameof(::∫) = :∫
146156

147-
"""
148-
$(SIGNATURES)
157+
function (I::∫)(x)
158+
Term{symtype(x)}(I, Any[x])
159+
end
160+
161+
function (I::∫)(x::Num)
162+
v = value(x)
163+
Num(I(v))
164+
end
149165

150-
Define one or more inputs.
166+
SymbolicUtils.promote_symtype(::Int, t) = t
151167

152-
See also [`@independent_variables`](@ref), [`@variables`](@ref) and [`@constants`](@ref).
153-
"""
154-
macro inputs(xs...)
155-
Symbolics._parse_vars(:inputs,
156-
Real,
157-
xs,
158-
toparam) |> esc
168+
# returns the JuMP timespan, the number of steps, and whether it is a free time problem.
169+
function process_tspan(tspan, dt, steps)
170+
is_free_time = false
171+
if isnothing(dt) && isnothing(steps)
172+
error("Must provide either the dt or the number of intervals to the collocation solvers (JuMP, InfiniteOpt, CasADi).")
173+
elseif symbolic_type(tspan[1]) === ScalarSymbolic() || symbolic_type(tspan[2]) === ScalarSymbolic()
174+
isnothing(steps) && error("Free final time problems require specifying the number of steps, rather than dt.")
175+
isnothing(dt) || @warn "Specified dt for free final time problem. This will be ignored; dt will be determined by the number of timesteps."
176+
177+
return steps, true
178+
else
179+
isnothing(steps) || @warn "Specified number of steps for problem with concrete tspan. This will be ignored; number of steps will be determined by dt."
180+
181+
return length(tspan[1]:dt:tspan[2]), false
182+
end
159183
end
184+
185+
#"""
186+
#$(SIGNATURES)
187+
#
188+
#Define one or more inputs.
189+
#
190+
#See also [`@independent_variables`](@ref), [`@variables`](@ref) and [`@constants`](@ref).
191+
#"""
192+
#macro inputs(xs...)
193+
# Symbolics._parse_vars(:inputs,
194+
# Real,
195+
# xs,
196+
# toparam) |> esc
197+
#end

src/variables.jl

+5
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ function hasbounds(x)
332332
any(isfinite.(b[1]) .|| isfinite.(b[2]))
333333
end
334334

335+
function setbounds(x::Num, bounds)
336+
(lb, ub) = bounds
337+
setmetadata(x, VariableBounds, (lb, ub))
338+
end
339+
335340
## Disturbance =================================================================
336341
struct VariableDisturbance end
337342
Symbolics.option_to_metadata_type(::Val{:disturbance}) = VariableDisturbance

0 commit comments

Comments
 (0)