Skip to content

feat: add costs, constraints, coalesce to @mtkmodel, introduce At operator #3584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ function var_derivative_graph! end
include("bipartite_graph.jl")
using .BipartiteGraphs

export EvalAt
include("variables.jl")
include("parameters.jl")
include("independent_variables.jl")
Expand Down
52 changes: 24 additions & 28 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ end
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Equation[],
constraintsystem = nothing,
constraints = Any[],
costs = Num[],
consolidate = nothing,
systems = ODESystem[],
Expand Down Expand Up @@ -276,11 +276,29 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."

constraintsystem = nothing
if !isempty(constraints)
constraintsystem = process_constraint_system(constraints, dvs, ps, iv)
for p in parameters(constraintsystem)
!in(p, Set(ps)) && push!(ps, p)
end
end

if !isempty(costs)
coststs, costps = process_costs(costs, dvs, ps, iv)
for p in costps
!in(p, Set(ps)) && push!(ps, p)
end
end
costs = wrap.(costs)

iv′ = value(iv)
ps′ = value.(ps)
ctrl′ = value.(controls)
dvs′ = value.(dvs)
dvs′ = filter(x -> !isdelay(x, iv), dvs′)

parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)
if !(isempty(default_u0) && isempty(default_p))
Expand Down Expand Up @@ -350,7 +368,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
function ODESystem(eqs, iv; kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
Expand Down Expand Up @@ -382,30 +400,8 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
end
algevars = setdiff(allunknowns, diffvars)

consvars = OrderedSet()
constraintsystem = nothing
if !isempty(constraints)
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
for st in get_unknowns(constraintsystem)
iscall(st) ?
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
!in(st, allunknowns) && push!(consvars, st)
end
for p in parameters(constraintsystem)
!in(p, new_ps) && push!(new_ps, p)
end
end

if !isempty(costs)
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
for p in costps
!in(p, new_ps) && push!(new_ps, p)
end
end
costs = wrap.(costs)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
collect(new_ps); constraintsystem, costs, kwargs...)
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
collect(new_ps); kwargs...)
end

# NOTE: equality does not check cached Jacobian
Expand Down Expand Up @@ -760,7 +756,7 @@ end
Build the constraint system for the ODESystem.
"""
function process_constraint_system(
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
constraints::Vector, sts, ps, iv; consname = :cons)
isempty(constraints) && return nothing

constraintsts = OrderedSet()
Expand Down Expand Up @@ -800,7 +796,7 @@ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 th
parameter of the system.
"""
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
sts = sysvars
sts = Set(sysvars)

for var in auxvars
if !iscall(var)
Expand Down
4 changes: 2 additions & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct DiscreteSystem <: AbstractDiscreteSystem
tearing_state = nothing, substitutions = nothing, namespacing = true,
complete = false, index_cache = nothing, parent = nothing,
isscheduled = false;
checks::Union{Bool, Int} = true)
checks::Union{Bool, Int} = true, kwargs...)
if checks == true || (checks & CheckComponents) > 0
check_independent_variables([iv])
check_variables(dvs, iv)
Expand Down Expand Up @@ -199,7 +199,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems,
defaults, guesses, initializesystem, initialization_eqs, preface, connector_type,
parameter_dependencies, metadata, gui_metadata, kwargs...)
parameter_dependencies, metadata, gui_metadata)
end

function DiscreteSystem(eqs, iv; kwargs...)
Expand Down
42 changes: 39 additions & 3 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
ps, sps, vs, = [], [], []
c_evts = []
d_evts = []
cons = []
costs = []
kwargs = OrderedCollections.OrderedSet()
where_types = Union{Symbol, Expr}[]

Expand All @@ -80,7 +82,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
for arg in expr.args
if arg.head == :macrocall
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
sps, c_evts, d_evts, dict, mod, arg, kwargs, where_types)
sps, c_evts, d_evts, cons, costs, dict, mod, arg, kwargs, where_types)
elseif arg.head == :block
push!(exprs.args, arg)
elseif arg.head == :if
Expand Down Expand Up @@ -120,13 +122,15 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
GUIMetadata(GlobalRef(mod, name))

consolidate = get(dict, :consolidate, nothing)
description = get(dict, :description, "")

@inline pop_structure_dict!.(
Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters])

sys = :($type($(flatten_equations)(equations), $iv, variables, parameters;
name, description = $description, systems, gui_metadata = $gui_metadata, defaults))
name, description = $description, systems, gui_metadata = $gui_metadata, defaults,
costs = [$(costs...)], constraints = [$(cons...)], consolidate = $consolidate))

if length(ext) == 0
push!(exprs.args, :(var"#___sys___" = $sys))
Expand Down Expand Up @@ -610,7 +614,7 @@ function get_var(mod::Module, b)
end

function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
dict, mod, arg, kwargs, where_types)
cons, costs, dict, mod, arg, kwargs, where_types)
mname = arg.args[1]
body = arg.args[end]
if mname == Symbol("@description")
Expand Down Expand Up @@ -638,6 +642,12 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
parse_icon!(body, dict, icon, mod)
elseif mname == Symbol("@defaults")
parse_system_defaults!(exprs, arg, dict)
elseif mname == Symbol("@constraints")
parse_costs!(cons, dict, body)
elseif mname == Symbol("@costs")
parse_constraints!(costs, dict, body)
elseif mname == Symbol("@consolidate")
parse_consolidate!(body, dict)
else
error("$mname is not handled.")
end
Expand Down Expand Up @@ -1149,6 +1159,32 @@ function parse_discrete_events!(d_evts, dict, body)
end
end

function parse_constraints!(cons, dict, body)
dict[:constraints] = []
Base.remove_linenums!(body)
for arg in body.args
push!(cons, arg)
push!(dict[:constraints], readable_code.(cons)...)
end
end

function parse_costs!(costs, dict, body)
dict[:costs] = []
Base.remove_linenums!(body)
for arg in body.args
push!(costs, arg)
push!(dict[:costs], readable_code.(costs)...)
end
end

function parse_consolidate!(body, dict)
if !(occursin("->", string(body)) || occursin("=", string(body)))
error("Consolidate must be a function definition.")
else
dict[:consolidate] = body
end
end

function parse_icon!(body::String, dict, icon, mod)
icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons"))
dict[:icon] = icon[] = if isfile(body)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/unit_check.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy
all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3])
end

function validate(eq::Equation; info::String = "")
function validate(eq::Union{Inequality, Equation}; info::String = "")
if typeof(eq.lhs) == Connection
_validate(eq.rhs; info)
else
Expand Down
40 changes: 40 additions & 0 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,43 @@ getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing)

getshift(x::Num) = getshift(unwrap(x))
getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0)

###################
### Evaluate at ###
###################
struct EvalAt <: Symbolics.Operator
t::Union{Symbolic, Number}
end

function (A::EvalAt)(x::Symbolic)
if symbolic_type(x) == NotSymbolic() || !iscall(x)
if x isa Symbolics.CallWithMetadata
return x(A.t)
else
return x
end
end

if iscall(x) && operation(x) == getindex
arr = arguments(x)[1]
term(getindex, A(arr), arguments(x)[2:end]...)
elseif operation(x) isa Differential
x = default_toterm(x)
A(x)
else
length(arguments(x)) !== 1 &&
error("Variable $x has too many arguments. EvalAt can only be applied to one-argument variables.")
(symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x
return operation(x)(A.t)
end
end

function (A::EvalAt)(x::Union{Num, Symbolics.Arr})
wrap(A(unwrap(x)))
end
SymbolicUtils.isbinop(::EvalAt) = false

Base.nameof(::EvalAt) = :EvalAt
Base.show(io::IO, A::EvalAt) = print(io, "EvalAt(", A.t, ")")
Base.:(==)(A1::EvalAt, A2::EvalAt) = isequal(A1.t, A2.t)
Base.hash(A::EvalAt, u::UInt) = hash(A.t, u)
32 changes: 32 additions & 0 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1026,3 +1026,35 @@ end
@named sys = Float2Bool()
@test typeof(sys) == DiscreteSystem
end

@testset "Constraints, costs, consolidate" begin
@mtkmodel Example begin
@variables begin
x(t)
y(t)
end
@equations begin
x ~ y
end
@constraints begin
EvalAt(0.3)(x) ~ 3
y ≲ 4
end
@costs begin
x + y
EvalAt(1)(y)^2
end
@consolidate f(u) = u[1]^2 + log(u[2])
end

@named ex = Example()
ex = complete(ex)

costs = ModelingToolkit.get_costs(ex)
constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex))
@test isequal(costs[1], ex.x + ex.y)
@test isequal(costs[2], EvalAt(1)(ex.y)^2)
@test isequal(constrs[1], -3 + EvalAt(0.3)(ex.x) ~ 0)
@test isequal(constrs[2], -4 + ex.y ≲ 0)
@test ModelingToolkit.get_consolidate(ex)([1, 2]) ≈ 1 + log(2)
end
26 changes: 26 additions & 0 deletions test/variable_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,29 @@ end
@test !isinitial(c)
@test !isinitial(x)
end

@testset "At" begin
@independent_variables u
@variables x(t) v(..) w(t)[1:3]
@parameters y z(u, t) r[1:3]

@test EvalAt(1)(x) isa Num
@test isequal(EvalAt(1)(y), y)
@test_throws ErrorException EvalAt(1)(z)
@test isequal(EvalAt(1)(v), v(1))
@test isequal(EvalAt(1)(v(t)), v(1))
@test isequal(EvalAt(1)(v(2)), v(2))

arr = EvalAt(1)(w)
var = EvalAt(1)(w[1])
@test arr isa Symbolics.Arr
@test var isa Num

@test isequal(EvalAt(1)(r), r)
@test isequal(EvalAt(1)(r[2]), r[2])

_x = ModelingToolkit.unwrap(x)
@test EvalAt(1)(_x) isa Symbolics.BasicSymbolic
@test only(arguments(EvalAt(1)(_x))) == 1
@test EvalAt(1)(D(x)) isa Num
end
Loading