Skip to content

Commit c0446c1

Browse files
Merge pull request #3304 from AayushSabharwal/as/connect-causal-variables
feat: add support for causal connections of variables
2 parents f7f0221 + f6728dd commit c0446c1

File tree

7 files changed

+250
-4
lines changed

7 files changed

+250
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ StaticArrays = "0.10, 0.11, 0.12, 1.0"
143143
StochasticDiffEq = "6.72.1"
144144
StochasticDelayDiffEq = "1.8.1"
145145
SymbolicIndexingInterface = "0.3.36"
146-
SymbolicUtils = "3.7"
146+
SymbolicUtils = "3.10"
147147
Symbolics = "6.22.1"
148148
URIs = "1"
149149
UnPack = "0.1, 1.0"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ include("systems/index_cache.jl")
146146
include("systems/parameter_buffer.jl")
147147
include("systems/abstractsystem.jl")
148148
include("systems/model_parsing.jl")
149-
include("systems/analysis_points.jl")
150149
include("systems/connectors.jl")
150+
include("systems/analysis_points.jl")
151151
include("systems/imperative_affect.jl")
152152
include("systems/callbacks.jl")
153153
include("systems/problem_utils.jl")

src/systems/abstractsystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1875,6 +1875,13 @@ Equivalent to `length(equations(expand_connections(sys))) - length(filter(eq ->
18751875
function n_expanded_connection_equations(sys::AbstractSystem)
18761876
# TODO: what about inputs?
18771877
isconnector(sys) && return length(get_unknowns(sys))
1878+
sys = remove_analysis_points(sys)
1879+
n_variable_connect_eqs = 0
1880+
for eq in equations(sys)
1881+
is_causal_variable_connection(eq.rhs) || continue
1882+
n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1
1883+
end
1884+
18781885
sys, (csets, _) = generate_connection_set(sys)
18791886
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
18801887
n_outer_stream_variables = 0
@@ -1897,7 +1904,7 @@ function n_expanded_connection_equations(sys::AbstractSystem)
18971904
# n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m))
18981905
#end
18991906

1900-
nextras = n_outer_stream_variables + length(ceqs)
1907+
nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs
19011908
end
19021909

19031910
function Base.show(

src/systems/analysis_points.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ function Symbolics.connect(in::AbstractSystem, name::Symbol, out, outs...; verbo
208208
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
209209
end
210210

211+
function Symbolics.connect(
212+
in::ConnectableSymbolicT, name::Symbol, out::ConnectableSymbolicT,
213+
outs::ConnectableSymbolicT...; verbose = true)
214+
allvars = (in, out, outs...)
215+
validate_causal_variables_connection(allvars)
216+
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
217+
end
218+
211219
"""
212220
$(TYPEDSIGNATURES)
213221
@@ -240,7 +248,7 @@ connection. This is the variable named `u` if present, and otherwise the only
240248
variable in the system. If the system does not have a variable named `u` and
241249
contains multiple variables, throw an error.
242250
"""
243-
function ap_var(sys)
251+
function ap_var(sys::AbstractSystem)
244252
if hasproperty(sys, :u)
245253
return sys.u
246254
end
@@ -249,6 +257,15 @@ function ap_var(sys)
249257
error("Could not determine the analysis-point variable in system $(nameof(sys)). To use an analysis point, apply it to a connection between causal blocks which have a variable named `u` or a single unknown of the same size.")
250258
end
251259

260+
"""
261+
$(TYPEDSIGNATURES)
262+
263+
For an `AnalysisPoint` involving causal variables. Simply return the variable.
264+
"""
265+
function ap_var(var::ConnectableSymbolicT)
266+
return var
267+
end
268+
252269
"""
253270
$(TYPEDEF)
254271

src/systems/connectors.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,93 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real
6868

6969
isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing
7070

71+
"""
72+
$(TYPEDEF)
73+
74+
Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show`
75+
to work.
76+
"""
77+
struct SymbolicWithNameof
78+
var::Any
79+
end
80+
81+
function Base.nameof(x::SymbolicWithNameof)
82+
return Symbol(x.var)
83+
end
84+
85+
is_causal_variable_connection(c) = false
86+
function is_causal_variable_connection(c::Connection)
87+
all(x -> x isa SymbolicWithNameof, get_systems(c))
88+
end
89+
90+
const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr}
91+
92+
const CAUSAL_CONNECTION_ERR = """
93+
Only causal variables can be used in a `connect` statement. The first argument must \
94+
be a single output variable and all subsequent variables must be input variables.
95+
"""
96+
97+
function VariableNotOutputError(var)
98+
ArgumentError("""
99+
$CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \
100+
in the variable metadata.
101+
""")
102+
end
103+
104+
function VariableNotInputError(var)
105+
ArgumentError("""
106+
$CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \
107+
in the variable metadata.
108+
""")
109+
end
110+
111+
"""
112+
$(TYPEDSIGNATURES)
113+
114+
Perform validation for a connect statement involving causal variables.
115+
"""
116+
function validate_causal_variables_connection(allvars)
117+
var1 = allvars[1]
118+
var2 = allvars[2]
119+
vars = Base.tail(Base.tail(allvars))
120+
for var in allvars
121+
vtype = getvariabletype(var)
122+
vtype === VARIABLE ||
123+
throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`."))
124+
end
125+
if length(unique(allvars)) !== length(allvars)
126+
throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries."))
127+
end
128+
allsizes = map(size, allvars)
129+
if !allequal(allsizes)
130+
throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively."))
131+
end
132+
isoutput(var1) || throw(VariableNotOutputError(var1))
133+
isinput(var2) || throw(VariableNotInputError(var2))
134+
for var in vars
135+
isinput(var) || throw(VariableNotInputError(var))
136+
end
137+
end
138+
139+
"""
140+
$(TYPEDSIGNATURES)
141+
142+
Connect multiple causal variables. The first variable must be an output, and all subsequent
143+
variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to:
144+
145+
```julia
146+
var1 ~ var2
147+
var1 ~ var3
148+
# ...
149+
```
150+
"""
151+
function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT,
152+
vars::ConnectableSymbolicT...)
153+
allvars = (var1, var2, vars...)
154+
validate_causal_variables_connection(allvars)
155+
return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars)))
156+
end
157+
71158
function flowvar(sys::AbstractSystem)
72159
sts = get_unknowns(sys)
73160
for s in sts
@@ -329,6 +416,10 @@ function generate_connection_set!(connectionsets, domain_csets,
329416
for eq in eqs′
330417
lhs = eq.lhs
331418
rhs = eq.rhs
419+
420+
# causal variable connections will be expanded before we get here,
421+
# but this guard is useful for `n_expanded_connection_equations`.
422+
is_causal_variable_connection(rhs) && continue
332423
if find !== nothing && find(rhs, _getname(namespace))
333424
neweq, extra_unknown = replace(rhs, _getname(namespace))
334425
if extra_unknown isa AbstractArray
@@ -479,9 +570,41 @@ function domain_defaults(sys, domain_csets)
479570
def
480571
end
481572

573+
"""
574+
$(TYPEDSIGNATURES)
575+
576+
Recursively descend through the hierarchy of `sys` and expand all connection equations
577+
of causal variables. Return the modified system.
578+
"""
579+
function expand_variable_connections(sys::AbstractSystem)
580+
eqs = copy(get_eqs(sys))
581+
valid_idxs = trues(length(eqs))
582+
additional_eqs = Equation[]
583+
584+
for (i, eq) in enumerate(eqs)
585+
eq.lhs isa Connection || continue
586+
connection = eq.rhs
587+
elements = connection.systems
588+
is_causal_variable_connection(connection) || continue
589+
590+
valid_idxs[i] = false
591+
elements = map(x -> x.var, elements)
592+
outvar = first(elements)
593+
for invar in Iterators.drop(elements, 1)
594+
push!(additional_eqs, outvar ~ invar)
595+
end
596+
end
597+
eqs = [eqs[valid_idxs]; additional_eqs]
598+
subsystems = map(expand_variable_connections, get_systems(sys))
599+
@set! sys.eqs = eqs
600+
@set! sys.systems = subsystems
601+
return sys
602+
end
603+
482604
function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
483605
debug = false, tol = 1e-10, scalarize = true)
484606
sys = remove_analysis_points(sys)
607+
sys = expand_variable_connections(sys)
485608
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
486609
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
487610
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)

test/causal_variables_connection.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
using ModelingToolkit, ModelingToolkitStandardLibrary.Blocks
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
4+
@testset "Error checking" begin
5+
@variables begin
6+
x(t)
7+
y(t), [input = true]
8+
z(t), [output = true]
9+
w(t)
10+
v(t), [input = true]
11+
u(t), [output = true]
12+
xarr(t)[1:4], [output = true]
13+
yarr(t)[1:2, 1:2], [input = true]
14+
end
15+
@parameters begin
16+
p, [input = true]
17+
q, [output = true]
18+
end
19+
20+
@test_throws ["p", "kind", "VARIABLE", "PARAMETER"] connect(z, p)
21+
@test_throws ["q", "kind", "VARIABLE", "PARAMETER"] connect(q, y)
22+
@test_throws ["p", "kind", "VARIABLE", "PARAMETER"] connect(z, y, p)
23+
24+
@test_throws ["unique"] connect(z, y, y)
25+
26+
@test_throws ["same size"] connect(xarr, yarr)
27+
28+
@test_throws ["Expected", "x", "output = true", "metadata"] connect(x, y)
29+
@test_throws ["Expected", "y", "output = true", "metadata"] connect(y, v)
30+
31+
@test_throws ["Expected", "x", "input = true", "metadata"] connect(z, x)
32+
@test_throws ["Expected", "x", "input = true", "metadata"] connect(z, y, x)
33+
@test_throws ["Expected", "u", "input = true", "metadata"] connect(z, u)
34+
@test_throws ["Expected", "u", "input = true", "metadata"] connect(z, y, u)
35+
end
36+
37+
@testset "Connection expansion" begin
38+
@named P = FirstOrder(k = 1, T = 1)
39+
@named C = Gain(; k = -1)
40+
41+
eqs = [connect(P.output.u, C.input.u)
42+
connect(C.output.u, P.input.u)]
43+
sys1 = ODESystem(eqs, t, systems = [P, C], name = :hej)
44+
sys = expand_connections(sys1)
45+
@test any(isequal(P.output.u ~ C.input.u), equations(sys))
46+
@test any(isequal(C.output.u ~ P.input.u), equations(sys))
47+
48+
@named sysouter = ODESystem(Equation[], t; systems = [sys1])
49+
sys = expand_connections(sysouter)
50+
@test any(isequal(sys1.P.output.u ~ sys1.C.input.u), equations(sys))
51+
@test any(isequal(sys1.C.output.u ~ sys1.P.input.u), equations(sys))
52+
end
53+
54+
@testset "With Analysis Points" begin
55+
@named P = FirstOrder(k = 1, T = 1)
56+
@named C = Gain(; k = -1)
57+
58+
ap = AnalysisPoint(:plant_input)
59+
eqs = [connect(P.output, C.input), connect(C.output.u, ap, P.input.u)]
60+
sys = ODESystem(eqs, t, systems = [P, C], name = :hej)
61+
@named nested_sys = ODESystem(Equation[], t; systems = [sys])
62+
63+
test_cases = [
64+
("inner", sys, sys.plant_input),
65+
("nested", nested_sys, nested_sys.hej.plant_input),
66+
("inner - Symbol", sys, :plant_input),
67+
("nested - Symbol", nested_sys, nameof(sys.plant_input))
68+
]
69+
70+
@testset "get_sensitivity - $name" for (name, sys, ap) in test_cases
71+
matrices, _ = get_sensitivity(sys, ap)
72+
@test matrices.A[] == -2
73+
@test matrices.B[] * matrices.C[] == -1 # either one negative
74+
@test matrices.D[] == 1
75+
end
76+
77+
@testset "get_comp_sensitivity - $name" for (name, sys, ap) in test_cases
78+
matrices, _ = get_comp_sensitivity(sys, ap)
79+
@test matrices.A[] == -2
80+
@test matrices.B[] * matrices.C[] == 1 # both positive or negative
81+
@test matrices.D[] == 0
82+
end
83+
84+
@testset "get_looptransfer - $name" for (name, sys, ap) in test_cases
85+
matrices, _ = get_looptransfer(sys, ap)
86+
@test matrices.A[] == -1
87+
@test matrices.B[] * matrices.C[] == -1 # either one negative
88+
@test matrices.D[] == 0
89+
end
90+
91+
@testset "open_loop - $name" for (name, sys, ap) in test_cases
92+
open_sys, (du, u) = open_loop(sys, ap)
93+
matrices, _ = linearize(open_sys, [du], [u])
94+
@test matrices.A[] == -1
95+
@test matrices.B[] * matrices.C[] == -1 # either one negative
96+
@test matrices.D[] == 0
97+
end
98+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ end
8585
@safetestset "Constraints Test" include("constraints.jl")
8686
@safetestset "IfLifting Test" include("if_lifting.jl")
8787
@safetestset "Analysis Points Test" include("analysis_points.jl")
88+
@safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl")
8889
end
8990
end
9091

0 commit comments

Comments
 (0)