Skip to content

Commit e0687d6

Browse files
feat: add substitute_component
1 parent 5cfb1b6 commit e0687d6

File tree

4 files changed

+448
-1
lines changed

4 files changed

+448
-1
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
288288
hasunit, getunit, hasconnect, getconnect,
289289
hasmisc, getmisc, state_priority
290290
export ode_order_lowering, dae_order_lowering, liouville_transform,
291-
change_independent_variable
291+
change_independent_variable, substitute_component
292292
export PDESystem
293293
export Differential, expand_derivatives, @derivatives
294294
export Equation, ConstrainedEquation

src/systems/abstractsystem.jl

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3142,3 +3142,180 @@ has_diff_eqs(osys21) # returns `false`.
31423142
```
31433143
"""
31443144
has_diff_eqs(sys::AbstractSystem) = any(is_diff_equation, get_eqs(sys))
3145+
3146+
"""
3147+
$(TYPEDSIGNATURES)
3148+
3149+
Validate the rules for replacement of subcomponents as defined in `substitute_component`.
3150+
"""
3151+
function validate_replacement_rule(
3152+
rule::Pair{T, T}; namespace = []) where {T <: AbstractSystem}
3153+
lhs, rhs = rule
3154+
3155+
iscomplete(lhs) && throw(ArgumentError("LHS of replacement rule cannot be completed."))
3156+
iscomplete(rhs) && throw(ArgumentError("RHS of replacement rule cannot be completed."))
3157+
3158+
rhs_h = namespace_hierarchy(nameof(rhs))
3159+
if length(rhs_h) != 1
3160+
throw(ArgumentError("RHS of replacement rule must not be namespaced."))
3161+
end
3162+
rhs_h[1] == namespace_hierarchy(nameof(lhs))[end] ||
3163+
throw(ArgumentError("LHS and RHS must have the same name."))
3164+
3165+
if !isequal(get_iv(lhs), get_iv(rhs))
3166+
throw(ArgumentError("LHS and RHS of replacement rule must have the same independent variable."))
3167+
end
3168+
3169+
lhs_u = get_unknowns(lhs)
3170+
rhs_u = Dict(get_unknowns(rhs) .=> nothing)
3171+
for u in lhs_u
3172+
if !haskey(rhs_u, u)
3173+
if isempty(namespace)
3174+
throw(ArgumentError("RHS of replacement rule does not contain unknown $u."))
3175+
else
3176+
throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain unknown $u."))
3177+
end
3178+
end
3179+
ru = getkey(rhs_u, u, nothing)
3180+
name = join([namespace; nameof(lhs); (hasname(u) ? getname(u) : Symbol(u))],
3181+
NAMESPACE_SEPARATOR)
3182+
l_connect = something(getconnect(u), Equality)
3183+
r_connect = something(getconnect(ru), Equality)
3184+
if l_connect != r_connect
3185+
throw(ArgumentError("Variable $(name) should have connection metadata $(l_connect),"))
3186+
end
3187+
3188+
l_input = isinput(u)
3189+
r_input = isinput(ru)
3190+
if l_input != r_input
3191+
throw(ArgumentError("Variable $name has differing causality. Marked as `input = $l_input` in LHS and `input = $r_input` in RHS."))
3192+
end
3193+
l_output = isoutput(u)
3194+
r_output = isoutput(ru)
3195+
if l_output != r_output
3196+
throw(ArgumentError("Variable $name has differing causality. Marked as `output = $l_output` in LHS and `output = $r_output` in RHS."))
3197+
end
3198+
end
3199+
3200+
lhs_p = get_ps(lhs)
3201+
rhs_p = Set(get_ps(rhs))
3202+
for p in lhs_p
3203+
if !(p in rhs_p)
3204+
if isempty(namespace)
3205+
throw(ArgumentError("RHS of replacement rule does not contain parameter $p"))
3206+
else
3207+
throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain parameter $p."))
3208+
end
3209+
end
3210+
end
3211+
3212+
lhs_s = get_systems(lhs)
3213+
rhs_s = Dict(nameof(s) => s for s in get_systems(rhs))
3214+
3215+
for s in lhs_s
3216+
if haskey(rhs_s, nameof(s))
3217+
rs = rhs_s[nameof(s)]
3218+
if isconnector(s)
3219+
name = join([namespace; nameof(lhs); nameof(s)], NAMESPACE_SEPARATOR)
3220+
if !isconnector(rs)
3221+
throw(ArgumentError("Subsystem $name of RHS is not a connector."))
3222+
end
3223+
if (lct = get_connector_type(s)) !== (rct = get_connector_type(rs))
3224+
throw(ArgumentError("Subsystem $name of RHS has connection type $rct but LHS has $lct."))
3225+
end
3226+
end
3227+
validate_replacement_rule(s => rs; namespace = [namespace; nameof(rhs)])
3228+
continue
3229+
end
3230+
name1 = join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)
3231+
throw(ArgumentError("$name1 of replacement rule does not contain subsystem $(nameof(s))."))
3232+
end
3233+
end
3234+
3235+
"""
3236+
$(TYPEDSIGNATURES)
3237+
3238+
Chain `getproperty` calls on `root` in the order given in `hierarchy`.
3239+
3240+
# Keyword Arguments
3241+
3242+
- `skip_namespace_first`: Whether to avoid namespacing in the first `getproperty` call.
3243+
"""
3244+
function recursive_getproperty(
3245+
root::AbstractSystem, hierarchy::Vector{Symbol}; skip_namespace_first = true)
3246+
cur = root
3247+
for (i, name) in enumerate(hierarchy)
3248+
cur = getproperty(cur, name; namespace = i > 1 || !skip_namespace_first)
3249+
end
3250+
return cur
3251+
end
3252+
3253+
"""
3254+
$(TYPEDSIGNATURES)
3255+
3256+
Recursively descend through `sys`, finding all connection equations and re-creating them
3257+
using the names of the involved variables/systems and finding the required variables/
3258+
systems in the hierarchy.
3259+
"""
3260+
function recreate_connections(sys::AbstractSystem)
3261+
eqs = map(get_eqs(sys)) do eq
3262+
eq.lhs isa Union{Connection, AnalysisPoint} || return eq
3263+
if eq.lhs isa Connection
3264+
oldargs = get_systems(eq.rhs)
3265+
else
3266+
ap::AnalysisPoint = eq.rhs
3267+
oldargs = [ap.input; ap.outputs]
3268+
end
3269+
newargs = map(get_systems(eq.rhs)) do arg
3270+
name = arg isa AbstractSystem ? nameof(arg) : getname(arg)
3271+
hierarchy = namespace_hierarchy(name)
3272+
return recursive_getproperty(sys, hierarchy)
3273+
end
3274+
if eq.lhs isa Connection
3275+
return eq.lhs ~ Connection(newargs)
3276+
else
3277+
return eq.lhs ~ AnalysisPoint(newargs[1], eq.rhs.name, newargs[2:end])
3278+
end
3279+
end
3280+
@set! sys.eqs = eqs
3281+
@set! sys.systems = map(recreate_connections, get_systems(sys))
3282+
return sys
3283+
end
3284+
3285+
"""
3286+
$(TYPEDSIGNATURES)
3287+
3288+
Given a hierarchical system `sys` and a rule `lhs => rhs`, replace the subsystem `lhs` in
3289+
`sys` by `rhs`. The `lhs` must be the namespaced version of a subsystem of `sys` (e.g.
3290+
obtained via `sys.inner.component`). The `rhs` must be valid as per the following
3291+
conditions:
3292+
3293+
1. `rhs` must not be namespaced.
3294+
2. The name of `rhs` must be the same as the unnamespaced name of `lhs`.
3295+
3. Neither one of `lhs` or `rhs` can be marked as complete.
3296+
4. Both `lhs` and `rhs` must share the same independent variable.
3297+
5. `rhs` must contain at least all of the unknowns and parameters present in
3298+
`lhs`.
3299+
6. Corresponding unknowns in `rhs` must share the same connection and causality
3300+
(input/output) metadata as their counterparts in `lhs`.
3301+
7. For each subsystem of `lhs`, there must be an identically named subsystem of `rhs`.
3302+
These two corresponding subsystems must satisfy conditions 3, 4, 5, 6, 7. If the
3303+
subsystem of `lhs` is a connector, the corresponding subsystem of `rhs` must also
3304+
be a connector of the same type.
3305+
3306+
`sys` also cannot be marked as complete.
3307+
"""
3308+
function substitute_component(sys::T, rule::Pair{T, T}) where {T <: AbstractSystem}
3309+
iscomplete(sys) &&
3310+
throw(ArgumentError("Cannot replace subsystems of completed systems"))
3311+
3312+
validate_replacement_rule(rule)
3313+
3314+
lhs, rhs = rule
3315+
hierarchy = namespace_hierarchy(nameof(lhs))
3316+
3317+
newsys, _ = modify_nested_subsystem(sys, hierarchy) do inner
3318+
return rhs, ()
3319+
end
3320+
return recreate_connections(newsys)
3321+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ end
9797
@safetestset "Analysis Points Test" include("analysis_points.jl")
9898
@safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl")
9999
@safetestset "Debugging Test" include("debugging.jl")
100+
@safetestset "Subsystem replacement" include("substitute_component.jl")
100101
end
101102
end
102103

0 commit comments

Comments
 (0)