Skip to content

Commit ac08f9a

Browse files
committed
test: more test fixes
1 parent 353f12b commit ac08f9a

File tree

6 files changed

+122
-41
lines changed

6 files changed

+122
-41
lines changed

ext/MTKJuMPControlExt.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct InfiniteOptControlProblem{uType, tType, isinplace, P, F, K} <:
3030
model::InfiniteModel
3131
kwargs::K
3232

33-
function InfiniteOptControlProblem(f, u0, tspan, p, model; kwargs...)
33+
function InfiniteOptControlProblem(f, u0, tspan, p, model, kwargs...)
3434
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f),
3535
typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
3636
end
@@ -58,7 +58,7 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
5858
guesses = Dict(), kwargs...)
5959
MTK.warn_overdetermined(sys, u0map)
6060
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
61-
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
61+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
6262
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
6363

6464
pmap = MTK.todict(pmap)
@@ -84,7 +84,7 @@ function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
8484
guesses = Dict(), kwargs...)
8585
MTK.warn_overdetermined(sys, u0map)
8686
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
87-
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
87+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
8888
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
8989

9090
pmap = MTK.todict(pmap)
@@ -116,16 +116,17 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
116116

117117
@infinite_parameter(model, t in [tspan[1], tspan[2]], num_supports=steps)
118118
@variable(model, U[i = 1:length(states)], Infinite(t), start=u0[i])
119-
c0 = [pmap[c] for c in ctrls]
119+
c0 = MTK.value.([pmap[c] for c in ctrls])
120120
@variable(model, V[i = 1:length(ctrls)], Infinite(t), start=c0[i])
121121

122122
set_jump_bounds!(model, sys, pmap)
123123
add_jump_cost_function!(model, sys, (tspan[1], tspan[2]), pmap; is_free_t)
124124
add_user_constraints!(model, sys, pmap; is_free_t)
125125

126126
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
127+
u0map = Dict([MTK.default_toterm(MTK.value(k)) => v for (k, v) in u0map])
127128
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) :
128-
[stidxmap[k] for (k, v) in u0map]
129+
[stidxmap[MTK.default_toterm(k)] for (k, v) in u0map]
129130
add_initial_constraints!(model, u0, u0_idxs, tspan[1])
130131
return model
131132
end
@@ -190,7 +191,10 @@ function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = fals
190191
end
191192
end
192193

193-
jconstraints = substitute_jump_vars(model, sys, pmap, jconstraints)
194+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
195+
jconstraints = substitute_jump_vars(model, sys, pmap, jconstraints; auxmap)
196+
197+
# Substitute to-term'd variables
194198
for (i, cons) in enumerate(jconstraints)
195199
if cons isa Equation
196200
@constraint(model, cons.lhs - cons.rhs==0, base_name="user[$i]")
@@ -207,25 +211,28 @@ function add_initial_constraints!(model::InfiniteModel, u0, u0_idxs, ts)
207211
@constraint(model, initial[i in u0_idxs], U[i](ts)==u0[i])
208212
end
209213

210-
function substitute_jump_vars(model, sys, pmap, exprs)
214+
function substitute_jump_vars(model, sys, pmap, exprs; auxmap = Dict())
211215
iv = MTK.get_iv(sys)
212216
sts = unknowns(sys)
213217
cts = MTK.unbound_inputs(sys)
214218
U = model[:U]
215219
V = model[:V]
220+
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
221+
216222
# for variables like x(t)
217223
whole_interval_map = Dict([[v => U[i] for (i, v) in enumerate(sts)];
218224
[v => V[i] for (i, v) in enumerate(cts)]])
219-
exprs = map(c -> Symbolics.substitute(c, whole_interval_map), exprs)
225+
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
220226

221227
# for variables like x(1.0)
222228
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
223229
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
224230
fixed_t_map = Dict([[x_ops[i] => U[i] for i in 1:length(U)];
225231
[c_ops[i] => V[i] for i in 1:length(V)]])
226-
exprs = map(c -> Symbolics.substitute(c, fixed_t_map), exprs)
227232

228-
exprs = map(c -> Symbolics.substitute(c, Dict(pmap)), exprs)
233+
exprs = map(c -> Symbolics.fixpoint_sub(c, fixed_t_map), exprs)
234+
235+
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
229236
exprs
230237
end
231238

@@ -255,8 +262,10 @@ function add_jump_solve_constraints!(prob, tableau; is_free_t = false)
255262
model = prob.model
256263
f = prob.f
257264
p = prob.p
265+
258266
t = model[:t]
259-
tsteps = supports(model[:t])
267+
tsteps = supports(t)
268+
tmax = tsteps[end]
260269
pop!(tsteps)
261270
tₛ = is_free_t ? model[:tf] : 1
262271
dt = tsteps[2] - tsteps[1]
@@ -280,6 +289,7 @@ function add_jump_solve_constraints!(prob, tableau; is_free_t = false)
280289
base_name="solve_time_")
281290
empty!(K)
282291
end
292+
@show num_variables(model)
283293
else
284294
@variable(model, K[1:length(α), 1:nᵤ], Infinite(t), start=tsteps[1])
285295
ΔUs = A * K
@@ -288,10 +298,10 @@ function add_jump_solve_constraints!(prob, tableau; is_free_t = false)
288298
for (i, h) in enumerate(c)
289299
ΔU = @view ΔUs[i, :]
290300
Uₙ = U + ΔU * dt
291-
@constraint(model, [j = 1:nᵤ], K[i, j](τ)==tₛ * f(Uₙ, V, p, τ + h * dt)[j],
292-
DomainRestrictions(t => τ + h * dt), base_name="solve_K()")
301+
@constraint(model, [j = 1:nᵤ], K[i, j](τ)==(tₛ * f(Uₙ, V, p, τ + h * dt)[j]),
302+
DomainRestrictions(t => min(τ + h * dt, tmax)), base_name="solve_K()")
293303
end
294-
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](τ + dt),
304+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](min(τ + dt, tmax)),
295305
DomainRestrictions(t => τ), base_name="solve_U()")
296306
end
297307
end
@@ -323,6 +333,7 @@ function DiffEqBase.solve(
323333
unregister(model, :K)
324334
for var in all_variables(model)
325335
if occursin("K", JuMP.name(var))
336+
unregister(model, Symbol(JuMP.name(var)))
326337
delete(model, var)
327338
end
328339
end

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
406406
end
407407
costs = wrap.(costs)
408408

409-
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
409+
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
410410
collect(new_ps); constraintsystem, costs, kwargs...)
411411
end
412412

@@ -769,7 +769,9 @@ function process_constraint_system(
769769
constraintps = OrderedSet()
770770
for cons in constraints
771771
collect_vars!(constraintsts, constraintps, cons, iv)
772+
union!(constraintsts, collect_applied_operators(cons, Differential))
772773
end
774+
@show constraintsts
773775

774776
# Validate the states.
775777
validate_vars_and_find_ps!(constraintsts, constraintps, sts, iv)
@@ -811,6 +813,9 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
811813
elseif length(arguments(var)) > 1
812814
throw(ArgumentError("Too many arguments for variable $var."))
813815
elseif length(arguments(var)) == 1
816+
if iscall(var) && operation(var) isa Differential
817+
var = only(arguments(var))
818+
end
814819
arg = only(arguments(var))
815820
operation(var)(iv) sts ||
816821
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))

src/systems/systems.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ function __structural_simplify(
163163
end
164164
end
165165

166+
function toterm_auxsystems(system::ODESystem)
167+
constraints = system.constraintsystem.constraints
168+
169+
end
170+
166171
"""
167172
$(TYPEDSIGNATURES)
168173
Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
[deps]
2+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
4+
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
25
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
36
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
47
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
5-
Multibody = "e1cad5d1-98ef-44f9-a79a-9ca4547f95b9"
8+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
9+
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
10+
OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
11+
OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf"
12+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
13+
OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
14+
SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"

test/dynamic_optimization/jump_control.jl

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ModelingToolkit
22
import JuMP, InfiniteOpt
33
using DiffEqDevTools, DiffEqBase
44
using SimpleDiffEq
5-
using OrdinaryDiffEqSDIRK
5+
using OrdinaryDiffEqSDIRK, OrdinaryDiffEqVerner, OrdinaryDiffEqTsit5, OrdinaryDiffEqFIRK
66
using Ipopt
77
using BenchmarkTools
88
using CairoMakie
@@ -100,8 +100,8 @@ end
100100
constr = [v(1.0) ~ 0.0]
101101
cost = [-x(1.0)] # Maximize the final distance.
102102
@named block = ODESystem(
103-
[D(x(t)) ~ v(t), D(v(t)) ~ u], t; costs = cost, constraints = constr)
104-
block, input_idxs = structural_simplify(block, ([u], []))
103+
[D(x(t)) ~ v(t), D(v(t)) ~ u(t)], t; costs = cost, constraints = constr)
104+
block, input_idxs = structural_simplify(block, ([u(t)], []))
105105

106106
u0map = [x(t) => 0.0, v(t) => 0.0]
107107
tspan = (0.0, 1.0)
@@ -113,19 +113,19 @@ end
113113
# Test reached final position.
114114
@test (jsol.sol.u[end][1], 0.25, rtol = 1e-5)
115115
# Test dynamics
116-
@parameters (u_interp::LinearInterpolation)(..)
117-
block_ode = ODESystem([D(x(t)) ~ v(t), D(v(t)) ~ u_interp(t)], t)
118-
spline = ctrl_to_spline(jsol.input_sol, LinearInterpolation)
119-
oprob = ODEProblem(block, u0map, tspan, [u_interp => spline])
120-
osol = solve(oprob, Vern8())
121-
@test jsol.sol.u osol.u
116+
@parameters (u_interp::ConstantInterpolation)(..)
117+
@mtkbuild block_ode = ODESystem([D(x(t)) ~ v(t), D(v(t)) ~ u_interp(t)], t)
118+
spline = ctrl_to_spline(jsol.input_sol, ConstantInterpolation)
119+
oprob = ODEProblem(block_ode, u0map, tspan, [u_interp => spline])
120+
osol = solve(oprob, Vern8(), dt = 0.01, adaptive = false)
121+
@test (jsol.sol.u, osol.u, rtol = 0.05)
122122

123123
iprob = InfiniteOptControlProblem(block, u0map, tspan, parammap; dt = 0.01)
124124
isol = solve(iprob, Ipopt.Optimizer; silent = true)
125125
@test is_bangbang(isol.input_sol, [-1.0], [1.0])
126126
@test (isol.sol.u[end][1], 0.25, rtol = 1e-5)
127-
osol = solve(oprob, ImplicitEuler())
128-
@test isol.sol.u osol.u
127+
osol = solve(oprob, ImplicitEuler(); dt = 0.01, adaptive = false)
128+
@test (isol.sol.u, osol.u, rtol = 0.05)
129129

130130
###################
131131
### Bee example ###
@@ -154,10 +154,12 @@ end
154154
@parameters (α_interp::LinearInterpolation)(..)
155155
eqs = [D(w(t)) ~ -μ * w(t) + b * s * α_interp(t) * w(t),
156156
D(q(t)) ~ -ν * q(t) + c * (1 - α_interp(t)) * s * w(t)]
157-
beesys_ode = ODESystem(eqs, t)
158-
oprob = ODEProblem(beesys_ode, u0map, tspan, [α_interp => ctrl_to_spline(jsol.input_sol, LinearInterpolation)])
159-
osol = solve(oprob, Tsit5())
160-
@test osol.u jsol.sol.u
157+
@mtkbuild beesys_ode = ODESystem(eqs, t)
158+
oprob = ODEProblem(beesys_ode, u0map, tspan, merge(Dict(pmap), Dict(α_interp => ctrl_to_spline(jsol.input_sol, LinearInterpolation))))
159+
osol = solve(oprob, Tsit5(); dt = 0.01, adaptive = false)
160+
@test (osol.u, jsol.sol.u, rtol = 0.01)
161+
osol2 = solve(oprob, ImplicitEuler(); dt = 0.01, adaptive = false)
162+
@test (osol2.u, isol.sol.u, rtol = 0.01)
161163
end
162164

163165
@testset "Rocket launch" begin
@@ -175,27 +177,31 @@ end
175177

176178
(ts, te) = (0.0, 0.2)
177179
costs = [-h(te)]
178-
constraints = [T(te) ~ 0]
179-
@named rocket = ODESystem(eqs, t; costs, constraints)
180+
cons = [T(te) ~ 0]
181+
@named rocket = ODESystem(eqs, t; costs, constraints = cons)
180182
rocket, input_idxs = structural_simplify(rocket, ([T(t)], []))
181183

182184
u0map = [h(t) => h₀, m(t) => m₀, v(t) => 0]
183185
pmap = [
184186
g₀ => 1, m₀ => 1.0, h_c => 500, c => 0.5 * (g₀ * h₀), D_c => 0.5 * 620 * m₀ / g₀,
185187
Tₘ => 3.5 * g₀ * m₀, T(t) => 0.0, h₀ => 1, m_c => 0.6]
186188
jprob = JuMPControlProblem(rocket, u0map, (ts, te), pmap; dt = 0.005, cse = false)
187-
jsol = solve(jprob, Ipopt.Optimizer, :RadauIA3)
189+
jsol = solve(jprob, Ipopt.Optimizer, :RadauIIA3)
188190
@test jsol.sol.u[end][1] > 1.012
191+
192+
iprob = InfiniteOptControlProblem(rocket, u0map, (ts, te), pmap; dt = 0.005, cse = false)
193+
isol = solve(iprob, Ipopt.Optimizer, derivative_method = OrthogonalCollocation(3))
194+
@test isol.sol.u[end][1] > 1.012
189195

190196
# Test solution
191197
@parameters (T_interp::CubicSpline)(..)
192198
eqs = [D(h(t)) ~ v(t),
193199
D(v(t)) ~ (T_interp(t) - drag(h(t), v(t))) / m(t) - gravity(h(t)),
194200
D(m(t)) ~ -T_interp(t) / c]
195-
rocket_ode = ODESystem(eqs, t)
196-
interpmap = Dict(T_interp => ctrl_to_spline(jsol.inputsol, CubicSpline))
197-
oprob = ODEProblem(rocket_ode, u0map, tspan, merge(pmap, interpmap))
198-
osol = solve(oprob, RadauIA3())
201+
@mtkbuild rocket_ode = ODESystem(eqs, t)
202+
interpmap = Dict(T_interp => ctrl_to_spline(jsol.input_sol, CubicSpline))
203+
oprob = ODEProblem(rocket_ode, u0map, (ts, te), merge(Dict(pmap), interpmap))
204+
osol = solve(oprob, RadauIIA3())
199205
@test jsol.sol.u osol.u
200206
end
201207

@@ -223,10 +229,44 @@ end
223229
@test isapprox(isol.sol.t[end], 10.0, rtol = 1e-3)
224230
end
225231

226-
using JuliaSimCompiler
227-
using Multibody.PlanarMechanics
228-
229232
@testset "Cart-pole problem" begin
233+
# gravity, length, moment of Inertia, drag coeff
234+
@parameters g l mₚ mₖ
235+
@variables x(..) θ(..) u(t) [input = true, bounds = (-10, 10)]
236+
237+
s = sin(θ(t))
238+
c = cos(θ(t))
239+
H = [mₖ+mₚ mₚ*l*c
240+
mₚ*l*c mₚ*l^2]
241+
C = [0 -mₚ*D(θ(t))*l*s
242+
0 0]
243+
qd = [D(x(t)), D(θ(t))]
244+
G = [0, mₚ*g*l*s]
245+
B = [1, 0]
246+
247+
tf = 5
248+
rhss = -H \ Vector(C*qd + G - B*u)
249+
eqs = [D(D(x(t))) ~ rhss[1], D(D(θ(t))) ~ rhss[2]]
250+
cons = [θ(tf) ~ π, x(tf) ~ 0, D(θ(tf)) ~ 0, D(x(tf)) ~ 0]
251+
costs = [(u^2)]
252+
tspan = (0, tf)
253+
254+
@named cartpole = ODESystem(eqs, t; costs, constraints = cons)
255+
cartpole, input_idxs = structural_simplify(cartpole, ([u], []))
256+
257+
u0map = [D(x(t)) => 0., D(θ(t)) => 0., θ(t) => 0., x(t) => 0.]
258+
pmap = [mₖ => 1., mₚ => 0.2, l => 0.5, g => 9.81, u => 0]
259+
jprob = JuMPControlProblem(cartpole, u0map, tspan, pmap; dt = 0.04)
260+
jsol = solve(jprob, Ipopt.Optimizer, :RK4)
261+
@test jsol.sol.u[end] [π, 0, 0, 0]
262+
263+
iprob = InfiniteOptControlProblem(cartpole, u0map, tspan, pmap; dt = 0.04)
264+
isol = solve(iprob, Ipopt.Optimizer)
265+
@test isol.sol.u[end] [π, 0, 0, 0]
266+
end
267+
268+
# RC Circuit
269+
@testset "MTK Components" begin
230270
end
231271

232272
#@testset "Constrained optimal control problems" begin

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ function activate_downstream_env()
2222
Pkg.instantiate()
2323
end
2424

25+
function activate_dynamic_optimization_env()
26+
Pkg.activate("dynamic_optimization")
27+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
28+
Pkg.instantiate()
29+
end
30+
2531
@time begin
2632
if GROUP == "All" || GROUP == "InterfaceI"
2733
@testset "InterfaceI" begin
@@ -143,4 +149,9 @@ end
143149
@safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl")
144150
@safetestset "JuMPControl Extension Test" include("extensions/jump_control.jl")
145151
end
152+
153+
if GROUP == "All" || GROUP == "Dynamic Optimization"
154+
activate_dynamic_optimization_env()
155+
@safetestset "JuMP Collocation Solvers" include("dynamic_optimization/jump_control")
156+
end
146157
end

0 commit comments

Comments
 (0)