Skip to content

Commit 6c53436

Browse files
committed
implement ControlFunction
1 parent b2d00f6 commit 6c53436

File tree

5 files changed

+143
-15
lines changed

5 files changed

+143
-15
lines changed

ext/MTKJuMPControlExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <:
1515
kwargs::K
1616

1717
function JuMPControlProblem(f, u0, tspan, p, model; kwargs...)
18-
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f),
18+
new{typeof(u0), typeof(tspan), SciMLBase.isinplace(f, 5),
1919
typeof(p), typeof(f), typeof(kwargs)}(f, u0, tspan, p, model, kwargs)
2020
end
2121
end
@@ -55,13 +55,10 @@ function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap;
5555
guesses = Dict(), kwargs...)
5656
MTK.warn_overdetermined(sys, u0map)
5757
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
58-
@show _u0map
59-
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
58+
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
6059
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
6160

62-
(f_i, f_o) = generate_control_function(sys)
6361
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, u0)
64-
6562
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
6663
end
6764

@@ -80,7 +77,7 @@ function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap;
8077
guesses = Dict(), kwargs...)
8178
MTK.warn_overdetermined(sys, u0map)
8279
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
83-
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
80+
f, u0, p = MTK.process_SciMLProblem(ControlFunction, sys, _u0map, pmap;
8481
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
8582

8683
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, u0)
@@ -237,14 +234,17 @@ function add_jump_solve_constraints!(prob, tableau)
237234
dt = tsteps[2] - tsteps[1]
238235

239236
U = model[:U]
237+
V = model[:V]
240238
nᵤ = length(U)
239+
nᵥ = length(V)
241240
if is_explicit(tableau)
242241
K = Any[]
243242
for τ in tsteps
244243
for (i, h) in enumerate(c)
245244
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
246245
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
247-
Kₙ = f(Uₙ, p, τ + h * dt)
246+
Vₙ = [V[i](τ) for i in 1:nᵥ]
247+
Kₙ = f(Uₙ, Vₙ, p, τ + h * dt)
248248
push!(K, Kₙ)
249249
end
250250
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
@@ -259,7 +259,7 @@ function add_jump_solve_constraints!(prob, tableau)
259259
for (i, h) in enumerate(c)
260260
ΔU = ΔUs[i, :]
261261
Uₙ = [U[j] + ΔU[j] * dt for j in 1:nᵤ]
262-
@constraint(model, [j in 1:nᵤ], K[i, j]==f(Uₙ, p, τ + h * dt)[j],
262+
@constraint(model, [j in 1:nᵤ], K[i, j]==f(Uₙ, V, p, τ + h * dt)[j],
263263
DomainRestrictions(t => τ), base_name="solve_K()")
264264
end
265265
ΔU = dt * sum([α[i] * K[i, :] for i in 1:length(α)])

src/inputoutput.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
208208
inputs = [inputs; disturbance_inputs]
209209
end
210210

211-
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
211+
if !iscomplete(sys)
212+
sys, _ = io_preprocessing(sys, inputs, []; simplify, kwargs...)
213+
end
212214

213215
dvs = unknowns(sys)
214216
ps = parameters(sys; initial_parameters = true)
@@ -250,9 +252,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
250252
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
251253
p_end = length(p) + 2 + implicit_dae)
252254
f = eval_or_rgf.(f; eval_expression, eval_module)
253-
f = GeneratedFunctionWrapper{(
254-
3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
255-
f = f, f
255+
f = GeneratedFunctionWrapper{(3, length(args) - length(p) + 1, is_split(sys))}(f...)
256256
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
257257
(; f, dvs, ps, io_sys = sys)
258258
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ function calculate_control_jacobian(sys::AbstractODESystem;
101101
end
102102

103103
rhs = [eq.rhs for eq in full_equations(sys)]
104-
ctrls = controls(sys)
104+
ctrls = unbound_inputs(sys)
105105

106106
if sparse
107107
jac = sparsejacobian(rhs, ctrls, simplify = simplify)

src/systems/optimal_control_interface.jl

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,134 @@ function warn_overdetermined(sys, u0map)
2020
end
2121
end
2222

23+
"""
24+
Generate the control function f(x, u, p, t) from the ODESystem.
25+
Input variables are automatically inferred but can be manually specified.
26+
"""
27+
function SciMLBase.ControlFunction{iip, specialize}(sys::ODESystem,
28+
dvs = unknowns(sys),
29+
ps = parameters(sys), u0 = nothing,
30+
inputs = unbound_inputs(sys),
31+
disturbance_inputs = disturbances(sys);
32+
version = nothing, tgrad = false,
33+
jac = false, controljac = false,
34+
p = nothing, t = nothing,
35+
eval_expression = false,
36+
sparse = false, simplify = false,
37+
eval_module = @__MODULE__,
38+
steady_state = false,
39+
checkbounds = false,
40+
sparsity = false,
41+
analytic = nothing,
42+
split_idxs = nothing,
43+
initialization_data = nothing,
44+
cse = true,
45+
kwargs...) where {iip, specialize}
46+
47+
(f), _, _ = generate_control_function(sys, inputs, disturbance_inputs; eval_expression = true, eval_module, cse, kwargs...)
48+
49+
if tgrad
50+
tgrad_gen = generate_tgrad(sys, dvs, ps;
51+
simplify = simplify,
52+
expression = Val{true},
53+
expression_module = eval_module, cse,
54+
checkbounds = checkbounds, kwargs...)
55+
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
56+
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
57+
else
58+
_tgrad = nothing
59+
end
60+
61+
if jac
62+
jac_gen = generate_jacobian(sys, dvs, ps;
63+
simplify = simplify, sparse = sparse,
64+
expression = Val{true},
65+
expression_module = eval_module, cse,
66+
checkbounds = checkbounds, kwargs...)
67+
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
68+
69+
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
70+
else
71+
_jac = nothing
72+
end
73+
74+
if controljac
75+
cjac_gen = generate_control_jacobian(sys, dvs, ps;
76+
simplify = simplify, sparse = sparse,
77+
expression = Val{true},
78+
expression_module = eval_module, cse,
79+
checkbounds = checkbounds, kwargs...)
80+
cjac_oop, cjac_iip = eval_or_rgf.(cjac_gen; eval_expression, eval_module)
81+
82+
_cjac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(cjac_oop, cjac_iip)
83+
else
84+
_cjac = nothing
85+
end
86+
87+
M = calculate_massmatrix(sys)
88+
_M = if sparse && !(u0 === nothing || M === I)
89+
SparseArrays.sparse(M)
90+
elseif u0 === nothing || M === I
91+
M
92+
else
93+
ArrayInterface.restructure(u0 .* u0', M)
94+
end
95+
96+
observedfun = ObservedFunctionCache(
97+
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
98+
99+
if sparse
100+
uElType = u0 === nothing ? Float64 : eltype(u0)
101+
W_prototype = similar(W_sparsity(sys), uElType)
102+
controljac_prototype = similar(calculate_control_jacobian(sys), uElType)
103+
else
104+
W_prototype = nothing
105+
controljac_prototype = nothing
106+
end
107+
108+
ControlFunction{iip, specialize}(f;
109+
sys = sys,
110+
jac = _jac === nothing ? nothing : _jac,
111+
controljac = _cjac === nothing ? nothing : _cjac,
112+
tgrad = _tgrad === nothing ? nothing : _tgrad,
113+
mass_matrix = _M,
114+
jac_prototype = W_prototype,
115+
controljac_prototype = controljac_prototype,
116+
observed = observedfun,
117+
sparsity = sparsity ? W_sparsity(sys) : nothing,
118+
analytic = analytic,
119+
initialization_data)
120+
end
121+
122+
function SciMLBase.ControlFunction(sys::AbstractODESystem, args...; kwargs...)
123+
ControlFunction{true}(sys, args...; kwargs...)
124+
end
125+
126+
function SciMLBase.ControlFunction{true}(sys::AbstractODESystem, args...;
127+
kwargs...)
128+
ControlFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
129+
end
130+
131+
function SciMLBase.ControlFunction{false}(sys::AbstractODESystem, args...;
132+
kwargs...)
133+
ControlFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
134+
end
135+
23136
"""
24137
IntegralNorm. When applied to an expression.
25138
"""
26139
struct IntegralNorm end
27140

141+
"""
142+
$(SIGNATURES)
143+
144+
Define one or more inputs.
145+
146+
See also [`@independent_variables`](@ref), [`@variables`](@ref) and [`@constants`](@ref).
147+
"""
148+
macro inputs(xs...)
149+
Symbolics._parse_vars(:inputs,
150+
Real,
151+
xs,
152+
toparam) |> esc
153+
end

test/extensions/jump_control.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ end
8686
end
8787

8888
# Double integrator
89+
t = M.t_nounits
90+
D = M.D_nounits
8991
@variables x(..) [bounds = (0., 0.25)] v(..)
9092
@variables u(t) [bounds = (-1., 1.), input = true]
9193
constr = [v(1.0) ~ 0.0]
@@ -106,8 +108,8 @@ end
106108
# Cart-pole system
107109

108110
# Bee example (from Lawrence Evans' notes)
109-
M.@variables w(..) q(..)
110-
M.@parameters α(t) [bounds = [0, 1]] b c μ s ν
111+
@variables w(..) q(..)
112+
@parameters α(t) [bounds = [0, 1]] b c μ s ν
111113

112114
tspan = (0, 4)
113115
eqs = [D(w(t)) ~ -μ*w(t) + b*s*α*w(t),

0 commit comments

Comments
 (0)