Skip to content

Commit 8c6b5ad

Browse files
Merge pull request #3584 from vyudu/opt_ctrl_utils
feat: add `costs, constraints, coalesce` to `@mtkmodel`, introduce `At` operator
2 parents b521932 + a30df65 commit 8c6b5ad

File tree

8 files changed

+165
-34
lines changed

8 files changed

+165
-34
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ function var_derivative_graph! end
142142
include("bipartite_graph.jl")
143143
using .BipartiteGraphs
144144

145+
export EvalAt
145146
include("variables.jl")
146147
include("parameters.jl")
147148
include("independent_variables.jl")

src/systems/diffeqs/odesystem.jl

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ end
247247
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
248248
controls = Num[],
249249
observed = Equation[],
250-
constraintsystem = nothing,
250+
constraints = Any[],
251251
costs = Num[],
252252
consolidate = nothing,
253253
systems = ODESystem[],
@@ -276,11 +276,29 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
276276
name === nothing &&
277277
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
278278
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
279+
280+
constraintsystem = nothing
281+
if !isempty(constraints)
282+
constraintsystem = process_constraint_system(constraints, dvs, ps, iv)
283+
for p in parameters(constraintsystem)
284+
!in(p, Set(ps)) && push!(ps, p)
285+
end
286+
end
287+
288+
if !isempty(costs)
289+
coststs, costps = process_costs(costs, dvs, ps, iv)
290+
for p in costps
291+
!in(p, Set(ps)) && push!(ps, p)
292+
end
293+
end
294+
costs = wrap.(costs)
295+
279296
iv′ = value(iv)
280297
ps′ = value.(ps)
281298
ctrl′ = value.(controls)
282299
dvs′ = value.(dvs)
283300
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
301+
284302
parameter_dependencies, ps′ = process_parameter_dependencies(
285303
parameter_dependencies, ps′)
286304
if !(isempty(default_u0) && isempty(default_p))
@@ -350,7 +368,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
350368
metadata, gui_metadata, is_dde, tstops, checks = checks)
351369
end
352370

353-
function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
371+
function ODESystem(eqs, iv; kwargs...)
354372
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
355373

356374
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -382,30 +400,8 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
382400
end
383401
algevars = setdiff(allunknowns, diffvars)
384402

385-
consvars = OrderedSet()
386-
constraintsystem = nothing
387-
if !isempty(constraints)
388-
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
389-
for st in get_unknowns(constraintsystem)
390-
iscall(st) ?
391-
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
392-
!in(st, allunknowns) && push!(consvars, st)
393-
end
394-
for p in parameters(constraintsystem)
395-
!in(p, new_ps) && push!(new_ps, p)
396-
end
397-
end
398-
399-
if !isempty(costs)
400-
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
401-
for p in costps
402-
!in(p, new_ps) && push!(new_ps, p)
403-
end
404-
end
405-
costs = wrap.(costs)
406-
407-
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
408-
collect(new_ps); constraintsystem, costs, kwargs...)
403+
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
404+
collect(new_ps); kwargs...)
409405
end
410406

411407
# NOTE: equality does not check cached Jacobian
@@ -760,7 +756,7 @@ end
760756
Build the constraint system for the ODESystem.
761757
"""
762758
function process_constraint_system(
763-
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
759+
constraints::Vector, sts, ps, iv; consname = :cons)
764760
isempty(constraints) && return nothing
765761

766762
constraintsts = OrderedSet()
@@ -800,7 +796,7 @@ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 th
800796
parameter of the system.
801797
"""
802798
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
803-
sts = sysvars
799+
sts = Set(sysvars)
804800

805801
for var in auxvars
806802
if !iscall(var)

src/systems/discrete_system/discrete_system.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ struct DiscreteSystem <: AbstractDiscreteSystem
121121
tearing_state = nothing, substitutions = nothing, namespacing = true,
122122
complete = false, index_cache = nothing, parent = nothing,
123123
isscheduled = false;
124-
checks::Union{Bool, Int} = true)
124+
checks::Union{Bool, Int} = true, kwargs...)
125125
if checks == true || (checks & CheckComponents) > 0
126126
check_independent_variables([iv])
127127
check_variables(dvs, iv)
@@ -199,7 +199,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
199199
DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
200200
eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems,
201201
defaults, guesses, initializesystem, initialization_eqs, preface, connector_type,
202-
parameter_dependencies, metadata, gui_metadata, kwargs...)
202+
parameter_dependencies, metadata, gui_metadata)
203203
end
204204

205205
function DiscreteSystem(eqs, iv; kwargs...)

src/systems/model_parsing.jl

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
6565
ps, sps, vs, = [], [], []
6666
c_evts = []
6767
d_evts = []
68+
cons = []
69+
costs = []
6870
kwargs = OrderedCollections.OrderedSet()
6971
where_types = Union{Symbol, Expr}[]
7072

@@ -80,7 +82,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
8082
for arg in expr.args
8183
if arg.head == :macrocall
8284
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
83-
sps, c_evts, d_evts, dict, mod, arg, kwargs, where_types)
85+
sps, c_evts, d_evts, cons, costs, dict, mod, arg, kwargs, where_types)
8486
elseif arg.head == :block
8587
push!(exprs.args, arg)
8688
elseif arg.head == :if
@@ -120,13 +122,15 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
120122
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
121123
GUIMetadata(GlobalRef(mod, name))
122124

125+
consolidate = get(dict, :consolidate, nothing)
123126
description = get(dict, :description, "")
124127

125128
@inline pop_structure_dict!.(
126129
Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters])
127130

128131
sys = :($type($(flatten_equations)(equations), $iv, variables, parameters;
129-
name, description = $description, systems, gui_metadata = $gui_metadata, defaults))
132+
name, description = $description, systems, gui_metadata = $gui_metadata, defaults,
133+
costs = [$(costs...)], constraints = [$(cons...)], consolidate = $consolidate))
130134

131135
if length(ext) == 0
132136
push!(exprs.args, :(var"#___sys___" = $sys))
@@ -610,7 +614,7 @@ function get_var(mod::Module, b)
610614
end
611615

612616
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
613-
dict, mod, arg, kwargs, where_types)
617+
cons, costs, dict, mod, arg, kwargs, where_types)
614618
mname = arg.args[1]
615619
body = arg.args[end]
616620
if mname == Symbol("@description")
@@ -638,6 +642,12 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
638642
parse_icon!(body, dict, icon, mod)
639643
elseif mname == Symbol("@defaults")
640644
parse_system_defaults!(exprs, arg, dict)
645+
elseif mname == Symbol("@constraints")
646+
parse_costs!(cons, dict, body)
647+
elseif mname == Symbol("@costs")
648+
parse_constraints!(costs, dict, body)
649+
elseif mname == Symbol("@consolidate")
650+
parse_consolidate!(body, dict)
641651
else
642652
error("$mname is not handled.")
643653
end
@@ -1149,6 +1159,32 @@ function parse_discrete_events!(d_evts, dict, body)
11491159
end
11501160
end
11511161

1162+
function parse_constraints!(cons, dict, body)
1163+
dict[:constraints] = []
1164+
Base.remove_linenums!(body)
1165+
for arg in body.args
1166+
push!(cons, arg)
1167+
push!(dict[:constraints], readable_code.(cons)...)
1168+
end
1169+
end
1170+
1171+
function parse_costs!(costs, dict, body)
1172+
dict[:costs] = []
1173+
Base.remove_linenums!(body)
1174+
for arg in body.args
1175+
push!(costs, arg)
1176+
push!(dict[:costs], readable_code.(costs)...)
1177+
end
1178+
end
1179+
1180+
function parse_consolidate!(body, dict)
1181+
if !(occursin("->", string(body)) || occursin("=", string(body)))
1182+
error("Consolidate must be a function definition.")
1183+
else
1184+
dict[:consolidate] = body
1185+
end
1186+
end
1187+
11521188
function parse_icon!(body::String, dict, icon, mod)
11531189
icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons"))
11541190
dict[:icon] = icon[] = if isfile(body)

src/systems/unit_check.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy
272272
all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3])
273273
end
274274

275-
function validate(eq::Equation; info::String = "")
275+
function validate(eq::Union{Inequality, Equation}; info::String = "")
276276
if typeof(eq.lhs) == Connection
277277
_validate(eq.rhs; info)
278278
else

src/variables.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,43 @@ getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing)
612612

613613
getshift(x::Num) = getshift(unwrap(x))
614614
getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0)
615+
616+
###################
617+
### Evaluate at ###
618+
###################
619+
struct EvalAt <: Symbolics.Operator
620+
t::Union{Symbolic, Number}
621+
end
622+
623+
function (A::EvalAt)(x::Symbolic)
624+
if symbolic_type(x) == NotSymbolic() || !iscall(x)
625+
if x isa Symbolics.CallWithMetadata
626+
return x(A.t)
627+
else
628+
return x
629+
end
630+
end
631+
632+
if iscall(x) && operation(x) == getindex
633+
arr = arguments(x)[1]
634+
term(getindex, A(arr), arguments(x)[2:end]...)
635+
elseif operation(x) isa Differential
636+
x = default_toterm(x)
637+
A(x)
638+
else
639+
length(arguments(x)) !== 1 &&
640+
error("Variable $x has too many arguments. EvalAt can only be applied to one-argument variables.")
641+
(symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x
642+
return operation(x)(A.t)
643+
end
644+
end
645+
646+
function (A::EvalAt)(x::Union{Num, Symbolics.Arr})
647+
wrap(A(unwrap(x)))
648+
end
649+
SymbolicUtils.isbinop(::EvalAt) = false
650+
651+
Base.nameof(::EvalAt) = :EvalAt
652+
Base.show(io::IO, A::EvalAt) = print(io, "EvalAt(", A.t, ")")
653+
Base.:(==)(A1::EvalAt, A2::EvalAt) = isequal(A1.t, A2.t)
654+
Base.hash(A::EvalAt, u::UInt) = hash(A.t, u)

test/model_parsing.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,3 +1026,35 @@ end
10261026
@named sys = Float2Bool()
10271027
@test typeof(sys) == DiscreteSystem
10281028
end
1029+
1030+
@testset "Constraints, costs, consolidate" begin
1031+
@mtkmodel Example begin
1032+
@variables begin
1033+
x(t)
1034+
y(t)
1035+
end
1036+
@equations begin
1037+
x ~ y
1038+
end
1039+
@constraints begin
1040+
EvalAt(0.3)(x) ~ 3
1041+
y 4
1042+
end
1043+
@costs begin
1044+
x + y
1045+
EvalAt(1)(y)^2
1046+
end
1047+
@consolidate f(u) = u[1]^2 + log(u[2])
1048+
end
1049+
1050+
@named ex = Example()
1051+
ex = complete(ex)
1052+
1053+
costs = ModelingToolkit.get_costs(ex)
1054+
constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex))
1055+
@test isequal(costs[1], ex.x + ex.y)
1056+
@test isequal(costs[2], EvalAt(1)(ex.y)^2)
1057+
@test isequal(constrs[1], -3 + EvalAt(0.3)(ex.x) ~ 0)
1058+
@test isequal(constrs[2], -4 + ex.y 0)
1059+
@test ModelingToolkit.get_consolidate(ex)([1, 2]) 1 + log(2)
1060+
end

test/variable_utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,29 @@ end
158158
@test !isinitial(c)
159159
@test !isinitial(x)
160160
end
161+
162+
@testset "At" begin
163+
@independent_variables u
164+
@variables x(t) v(..) w(t)[1:3]
165+
@parameters y z(u, t) r[1:3]
166+
167+
@test EvalAt(1)(x) isa Num
168+
@test isequal(EvalAt(1)(y), y)
169+
@test_throws ErrorException EvalAt(1)(z)
170+
@test isequal(EvalAt(1)(v), v(1))
171+
@test isequal(EvalAt(1)(v(t)), v(1))
172+
@test isequal(EvalAt(1)(v(2)), v(2))
173+
174+
arr = EvalAt(1)(w)
175+
var = EvalAt(1)(w[1])
176+
@test arr isa Symbolics.Arr
177+
@test var isa Num
178+
179+
@test isequal(EvalAt(1)(r), r)
180+
@test isequal(EvalAt(1)(r[2]), r[2])
181+
182+
_x = ModelingToolkit.unwrap(x)
183+
@test EvalAt(1)(_x) isa Symbolics.BasicSymbolic
184+
@test only(arguments(EvalAt(1)(_x))) == 1
185+
@test EvalAt(1)(D(x)) isa Num
186+
end

0 commit comments

Comments
 (0)