Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate types of state variables in ODESystem/SDESystem construction #3340

Merged
merged 17 commits into from
Jan 31, 2025
47 changes: 7 additions & 40 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,53 +296,21 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
end

function ODESystem(eqs, iv; kwargs...)
eqs = collect(eqs)
# NOTE: this assumes that the order of algebraic equations doesn't matter
diffvars = OrderedSet()
allunknowns = OrderedSet()
ps = OrderedSet()
# reorder equations such that it is in the form of `diffeq, algeeq`
diffeq = Equation[]
algeeq = Equation[]
# initial loop for finding `iv`
if iv === nothing
for eq in eqs
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
iv = iv_from_nested_derivative(eq.lhs)
break
end
end
end
iv = value(iv)
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
for eq in eqs
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
collect_vars!(allunknowns, ps, eq, iv)
if isdiffeq(eq)
diffvar, _ = var_from_nested_derivative(eq.lhs)
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
throw(ArgumentError("An ODESystem can only have one independent variable."))
diffvar in diffvars &&
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
push!(diffvars, diffvar)
end
push!(diffeq, eq)
else
push!(algeeq, eq)
end
end
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
collect_vars!(allunknowns, ps, eq, iv)
end

for ssys in get(kwargs, :systems, ODESystem[])
collect_scoped_vars!(allunknowns, ps, ssys, iv)
end

for v in allunknowns
isdelay(v, iv) || continue
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
end

new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
Expand All @@ -358,9 +326,8 @@ function ODESystem(eqs, iv; kwargs...)
end
end
algevars = setdiff(allunknowns, diffvars)
# the orders here are very important!
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
end

# NOTE: equality does not check cached Jacobian
Expand Down
46 changes: 46 additions & 0 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,52 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
SDESystem(equations(sys), neqs, get_iv(sys), unknowns(sys), parameters(sys); kwargs...)
end

function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
collect_vars!(allunknowns, ps, eq, iv)
end

for ssys in get(kwargs, :systems, ODESystem[])
collect_scoped_vars!(allunknowns, ps, ssys, iv)
end

for v in allunknowns
isdelay(v, iv) || continue
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
end

new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
par = arguments(p)[begin]
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
all(par[i] in ps for i in eachindex(par))
push!(new_ps, par)
else
push!(new_ps, p)
end
else
push!(new_ps, p)
end
end

# validate noise equations
noisedvs = OrderedSet()
noiseps = OrderedSet()
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
for dv in noisedvs
dv ∈ allunknowns || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
end
algevars = setdiff(allunknowns, diffvars)

return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)), [ps; collect(noiseps)]; kwargs...)
end

SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...) = SDESystem([eq], noiseeqs, args...; kwargs...)
SDESystem(eq::Equation, noiseeq, args...; kwargs...) = SDESystem([eq], [noiseeq], args...; kwargs...)

function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
sys1 === sys2 && return true
iv1 = get_iv(sys1)
Expand Down
51 changes: 51 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1185,3 +1185,54 @@ function guesses_from_metadata!(guesses, vars)
guesses[vars[i]] = varguesses[i]
end
end

"""
$(TYPEDSIGNATURES)

Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters.
"""
function process_equations(eqs, iv)
eqs = collect(eqs)

diffvars = OrderedSet()
allunknowns = OrderedSet()
ps = OrderedSet()

# NOTE: this assumes that the order of algebraic equations doesn't matter
# reorder equations such that it is in the form of `diffeq, algeeq`
diffeq = Equation[]
algeeq = Equation[]
# initial loop for finding `iv`
if iv === nothing
for eq in eqs
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
iv = iv_from_nested_derivative(eq.lhs)
break
end
end
end
iv = value(iv)
iv === nothing && throw(ArgumentError("Please pass in independent variables."))

compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
for eq in eqs
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
collect_vars!(allunknowns, ps, eq, iv)
if isdiffeq(eq)
diffvar, _ = var_from_nested_derivative(eq.lhs)
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
throw(ArgumentError("An ODESystem can only have one independent variable."))
diffvar in diffvars &&
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) && throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
push!(diffvars, diffvar)
end
push!(diffeq, eq)
else
push!(algeeq, eq)
end
end

diffvars, allunknowns, ps, Equation[diffeq; algeeq; compressed_eqs]
end
10 changes: 10 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,16 @@ end
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
end

@testset "Validate input types" begin
@parameters p d
@variables X(t)::Int64
eq = D(X) ~ p - d*X
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
@variables Y(t)[1:3]::String
eq = D(Y) ~ [p, p, p]
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
end

# Test `isequal`
@testset "`isequal`" begin
@variables X(t)
Expand Down
13 changes: 12 additions & 1 deletion test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,17 @@ end
@test length(observed(sys)) == 1
end

# Test validating types of states
@testset "Validate input types" begin
@parameters p d
@variables X(t)::Int64
@brownian z
eq2 = D(X) ~ p - d*X + z
@test_throws ArgumentError @mtkbuild ssys = System([eq2], t)
noiseeq = [1]
@test_throws ArgumentError @named ssys = SDESystem([eq2], [noiseeq], t)
end

@testset "SDEFunctionExpr" begin
@parameters σ ρ β
@variables x(tt) y(tt) z(tt)
Expand Down Expand Up @@ -953,4 +964,4 @@ end
@test_throws ErrorException("SDESystem constructed by defining Brownian variables with @brownian must be simplified by calling `structural_simplify` before a SDEProblem can be constructed.") SDEProblem(de, u0map, (0.0, 100.0), parammap)
de = structural_simplify(de)
@test SDEProblem(de, u0map, (0.0, 100.0), parammap) isa SDEProblem
end
end
Loading