Skip to content

Commit f2e3ffd

Browse files
committed
fix: use dt in the constraints
1 parent 1133e4e commit f2e3ffd

File tree

5 files changed

+59
-40
lines changed

5 files changed

+59
-40
lines changed

Project.toml

-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
3131
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
3232
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
3333
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
34-
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
3534
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
3635
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
3736
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
@@ -120,7 +119,6 @@ FunctionWrappersWrappers = "0.1"
120119
Graphs = "1.5.2"
121120
InfiniteOpt = "0.5"
122121
InteractiveUtils = "1"
123-
Ipopt = "1.8.0"
124122
JuMP = "1.25.0"
125123
JuliaFormatter = "1.0.47"
126124
JumpProcesses = "9.13.1"

ext/MTKJuMPControlExt.jl

+44-21
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using DiffEqDevTools, DiffEqBase, SciMLBase
55
using LinearAlgebra
66
const MTK = ModelingToolkit
77

8-
struct JuMPControlProblem{uType, tType, P, F, K}
8+
struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <: SciMLBase.AbstractODEProblem{uType, tType, isinplace}
99
f::F
1010
u0::uType
1111
tspan::tType
@@ -14,7 +14,7 @@ struct JuMPControlProblem{uType, tType, P, F, K}
1414
kwargs::K
1515

1616
function JuMPControlProblem(f, u0, tspan, p, model; kwargs...)
17-
new{typeof(u0), typeof(tspan), typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
17+
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
1818
end
1919
end
2020

@@ -50,9 +50,9 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("
5050
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
5151

5252
model = InfiniteModel()
53-
@infinite_parameter(model, t in [ts, te], num_supports = length(steps), derivative_method = OrthogonalCollocation(2))
54-
@variable(model, U[1:length(states)], Infinite(t), start = ts)
55-
@variable(model, V[1:length(ctrls)], Infinite(t), start = ts)
53+
@infinite_parameter(model, t in [ts, te], num_supports = length(steps))
54+
@variable(model, U[i = 1:length(states)], Infinite(t))
55+
@variable(model, V[1:length(ctrls)], Infinite(t))
5656

5757
add_jump_cost_function!(model, sys)
5858
add_user_constraints!(model, sys)
@@ -136,7 +136,8 @@ end
136136

137137
function add_initial_constraints!(model, u0, u0_idxs, tspan)
138138
ts = tspan[1]
139-
@constraint(model, init_u0_idx[i in u0_idxs], model[:U][i](ts) == u0[i])
139+
U = model[:U]
140+
@constraint(model, initial[i in u0_idxs], U[i](ts) == u0[i])
140141
end
141142

142143
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
@@ -148,6 +149,7 @@ function add_solve_constraints!(prob, tableau)
148149
model = prob.model
149150
f = prob.f
150151
p = prob.p
152+
t = model[:t]
151153
tsteps = supports(model[:t])
152154
pop!(tsteps)
153155
dt = tsteps[2] - tsteps[1]
@@ -163,21 +165,21 @@ function add_solve_constraints!(prob, tableau)
163165
Kₙ = f(Uₙ, p, τ + h*dt)
164166
push!(K, Kₙ)
165167
end
166-
ΔU = sum([α[i] * K[i] for i in 1:length(α)])
167-
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU[n] == U[n](τ + dt))
168+
ΔU = dt*sum([α[i] * K[i] for i in 1:length(α)])
169+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU[n] == U[n](τ + dt), base_name = "solve_time_")
168170
empty!(K)
169171
end
170172
else
171-
@variable(model, K[1:length(a), 1:nᵤ], Infinite(t), start = tsteps[1])
173+
@variable(model, K[1:length(α), 1:nᵤ], Infinite(t), start = tsteps[1])
172174
for τ in tsteps
173-
ΔUs = [A * K(τ)]
175+
ΔUs = A * K
174176
for (i, h) in enumerate(c)
175-
ΔU = ΔUs[i]
176-
Uₙ = [U[j](τ) + ΔU[j](τ)*dt for j in 1:nᵤ]
177-
@constraint(model, K[i](τ) == f(Uₙ, p, τ + h*dt))
177+
ΔU = ΔUs[i, :]
178+
Uₙ = [U[j] + ΔU[j]*dt for j in 1:nᵤ]
179+
@constraint(model, [j in 1:nᵤ], K[i, j] == f(Uₙ, p, τ + h*dt)[j], DomainRestrictions(t => τ), base_name = "solve_K()")
178180
end
179-
ΔU = sum([α[i] * K[i] for i in 1:length(α)])
180-
@constraint(model, U(τ) + dot(α, K(τ)) == U+ dt))
181+
ΔU = dt*sum([α[i] * K[i, :] for i in 1:length(α)])
182+
@constraint(model, [n = 1:nᵤ], U[n] + ΔU[n] == U[n](τ + dt), DomainRestrictions(t => τ), base_name = "solve_U()")
181183
end
182184
end
183185
end
@@ -194,25 +196,46 @@ Solve JuMPControlProblem. Arguments:
194196
- prob: a JumpControlProblem
195197
- jump_solver: a LP solver such as HiGHS
196198
- ode_solver: Takes in a symbol representing the solver. Acceptable solvers may be found at https://docs.sciml.ai/DiffEqDevDocs/stable/internals/tableaus/. Note that the symbol may be different than the typical name of the solver, e.g. :Tsitouras5 rather than Tsit5.
199+
200+
Returns a JuMPControlSolution, which contains both the model and the ODE solution.
197201
"""
198202
function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Symbol)
199203
model = prob.model
200204
tableau_getter = Symbol(:construct, ode_solver)
201205
@eval tableau = $tableau_getter()
202206
ts = supports(model[:t])
207+
208+
# Unregister current solver constraints
209+
for con in all_constraints(model)
210+
if occursin("solve", JuMP.name(con))
211+
unregister(model, Symbol(JuMP.name(con)))
212+
delete(model, con)
213+
end
214+
end
215+
for var in all_variables(model)
216+
@show JuMP.name(var)
217+
if occursin("K", JuMP.name(var))
218+
unregister(model, Symbol(JuMP.name(var)))
219+
delete(model, var)
220+
end
221+
end
203222
add_solve_constraints!(prob, tableau)
204223

205224
set_optimizer(model, jump_solver)
206225
optimize!(model)
207226

208-
if is_solved_and_feasible(model)
209-
sol = DiffEqBase.build_solution(prob, ode_solver, ts, value.(U))
210-
JuMPControlSolution(model, sol)
211-
else
212-
sol = DiffEqBase.build_solution(prob, ode_solver, ts, value.(U))
227+
tstatus = termination_status(model)
228+
pstatus = primal_status(model)
229+
!has_values(model) && error("Model not solvable; please report this to github.com/SciML/ModelingToolkit.jl.")
230+
231+
U_vals = value.(model[:U])
232+
U_vals = [[U_vals[i][j] for i in 1:length(U_vals)] for j in 1:length(ts)]
233+
sol = DiffEqBase.build_solution(prob, ode_solver, ts, U_vals)
234+
235+
if !(pstatus === FEASIBLE_POINT && (tstatus === OPTIMAL || tstatus === LOCALLY_SOLVED || tstatus === ALMOST_OPTIMAL || tstatus === ALMOST_LOCALLY_SOLVED))
213236
sol = SciMLBase.solution_new_retcode(sol, SciMLBase.ReturnCode.ConvergenceFailure)
214-
JuMPControlSolution(model, sol)
215237
end
238+
JuMPControlSolution(model, sol)
216239
end
217240

218241
end

test/extensions/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1414
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1515
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
1616
NonlinearSolveHomotopyContinuation = "2ac3b008-d579-4536-8c91-a1a5998c2f8b"
17-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1817
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
1918
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2019
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
20+
SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
2121
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2222
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2323
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/extensions/ad.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
33
using Zygote
44
using SymbolicIndexingInterface
55
using SciMLStructures
6-
using OrdinaryDiffEq
6+
using OrdinaryDiffEqTsit5
77
using NonlinearSolve
88
using SciMLSensitivity
99
using ForwardDiff

test/extensions/jump_control.jl

+13-15
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using ModelingToolkit
22
using JuMP, InfiniteOpt
33
using DiffEqDevTools, DiffEqBase
44
using SimpleDiffEq
5-
using HiGHS
5+
using OrdinaryDiffEqSDIRK
6+
using Ipopt
67
const M = ModelingToolkit
78

89
@testset "ODE Solution, no cost" begin
@@ -22,36 +23,33 @@ const M = ModelingToolkit
2223
# Test explicit method.
2324
jprob = JuMPControlProblem(sys, u0map, tspan, parammap, dt = 0.01)
2425
@test num_constraints(jprob.model) == 2 # initials
25-
jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
26+
jsol = solve(jprob, Ipopt.Optimizer, :RK4)
2627
oprob = ODEProblem(sys, u0map, tspan, parammap)
27-
osol = solve(oprob, SimpleTsit5(), adaptive = false)
28+
osol = solve(oprob, SimpleRK4(), dt = 0.01)
2829
@test jsol.sol.u osol.u
2930

3031
# Implicit method.
31-
jsol2 = solve(prob, Ipopt.Optimizer, :RK4)
32-
osol2 = solve(oprob, SimpleRK4(), adaptive = false)
33-
@test jsol2.sol.u osol2.u
32+
jsol2 = solve(jprob, Ipopt.Optimizer, :ImplicitEuler)
33+
osol2 = solve(oprob, ImplicitEuler(), dt = 0.01, adaptive = false)
34+
@test (jsol2.sol.u, osol2.u, rtol = 0.001)
3435

3536
# With a constraint
36-
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
37-
@variables x(..) y(..)
38-
39-
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
40-
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]
41-
42-
u0map = []
43-
tspan = (0.0, 1.0)
37+
u0map = Pair[]
4438
guess = [x(t) => 4.0, y(t) => 2.0]
4539
constr = [x(0.6) ~ 3.5, x(0.3) ~ 7.0]
4640
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
4741

4842
jprob = JuMPControlProblem(sys, u0map, tspan, parammap; guesses, dt = 0.01)
4943
@test num_constraints(jprob.model) == 2 == num_variables(jprob.model) == 2
50-
jsol = solve(prob, HiGHS.Optimizer, :Tsitouras5)
44+
jsol = solve(prob, Ipopt.Optimizer, :Tsitouras5)
5145
sol = jsol.sol
5246
@test sol(0.6)[1] 3.5
5347
@test sol(0.3)[1] 7.0
5448
end
5549

5650
@testset "Optimal control problem" begin
51+
# Investing
52+
53+
54+
# Bang-bang control
5755
end

0 commit comments

Comments
 (0)