Skip to content
47 changes: 7 additions & 40 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,53 +296,21 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
end

function ODESystem(eqs, iv; kwargs...)
eqs = collect(eqs)
# NOTE: this assumes that the order of algebraic equations doesn't matter
diffvars = OrderedSet()
allunknowns = OrderedSet()
ps = OrderedSet()
# reorder equations such that it is in the form of `diffeq, algeeq`
diffeq = Equation[]
algeeq = Equation[]
# initial loop for finding `iv`
if iv === nothing
for eq in eqs
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
iv = iv_from_nested_derivative(eq.lhs)
break
end
end
end
iv = value(iv)
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
for eq in eqs
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
collect_vars!(allunknowns, ps, eq, iv)
if isdiffeq(eq)
diffvar, _ = var_from_nested_derivative(eq.lhs)
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
throw(ArgumentError("An ODESystem can only have one independent variable."))
diffvar in diffvars &&
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
push!(diffvars, diffvar)
end
push!(diffeq, eq)
else
push!(algeeq, eq)
end
end
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
collect_vars!(allunknowns, ps, eq, iv)
end

for ssys in get(kwargs, :systems, ODESystem[])
collect_scoped_vars!(allunknowns, ps, ssys, iv)
end

for v in allunknowns
isdelay(v, iv) || continue
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
end

new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
Expand All @@ -358,9 +326,8 @@ function ODESystem(eqs, iv; kwargs...)
end
end
algevars = setdiff(allunknowns, diffvars)
# the orders here are very important!
return ODESystem(Equation[diffeq; algeeq; compressed_eqs], iv,
collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); kwargs...)
end

# NOTE: equality does not check cached Jacobian
Expand Down
46 changes: 46 additions & 0 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,52 @@ function SDESystem(sys::ODESystem, neqs; kwargs...)
SDESystem(equations(sys), neqs, get_iv(sys), unknowns(sys), parameters(sys); kwargs...)
end

function SDESystem(eqs::Vector{Equation}, noiseeqs::AbstractArray, iv; kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
collect_vars!(allunknowns, ps, eq, iv)
end

for ssys in get(kwargs, :systems, ODESystem[])
collect_scoped_vars!(allunknowns, ps, ssys, iv)
end

for v in allunknowns
isdelay(v, iv) || continue
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
end

new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
par = arguments(p)[begin]
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
all(par[i] in ps for i in eachindex(par))
push!(new_ps, par)
else
push!(new_ps, p)
end
else
push!(new_ps, p)
end
end

# validate noise equations
noisedvs = OrderedSet()
noiseps = OrderedSet()
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
for dv in noisedvs
dv ∈ allunknowns || throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
end
algevars = setdiff(allunknowns, diffvars)

return SDESystem(eqs, noiseeqs, iv, Iterators.flatten((diffvars, algevars)), [ps; collect(noiseps)]; kwargs...)
end

SDESystem(eq::Equation, noiseeqs::AbstractArray, args...; kwargs...) = SDESystem([eq], noiseeqs, args...; kwargs...)
SDESystem(eq::Equation, noiseeq, args...; kwargs...) = SDESystem([eq], [noiseeq], args...; kwargs...)

function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
sys1 === sys2 && return true
iv1 = get_iv(sys1)
Expand Down
51 changes: 51 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1185,3 +1185,54 @@ function guesses_from_metadata!(guesses, vars)
guesses[vars[i]] = varguesses[i]
end
end

"""
$(TYPEDSIGNATURES)

Find all the unknowns and parameters from the equations of a SDESystem or ODESystem. Return re-ordered equations, differential variables, all variables, and parameters.
"""
function process_equations(eqs, iv)
eqs = collect(eqs)

diffvars = OrderedSet()
allunknowns = OrderedSet()
ps = OrderedSet()

# NOTE: this assumes that the order of algebraic equations doesn't matter
# reorder equations such that it is in the form of `diffeq, algeeq`
diffeq = Equation[]
algeeq = Equation[]
# initial loop for finding `iv`
if iv === nothing
for eq in eqs
if !(eq.lhs isa Number) # assume eq.lhs is either Differential or Number
iv = iv_from_nested_derivative(eq.lhs)
break
end
end
end
iv = value(iv)
iv === nothing && throw(ArgumentError("Please pass in independent variables."))

compressed_eqs = Equation[] # equations that need to be expanded later, like `connect(a, b)`
for eq in eqs
eq.lhs isa Union{Symbolic, Number} || (push!(compressed_eqs, eq); continue)
collect_vars!(allunknowns, ps, eq, iv)
if isdiffeq(eq)
diffvar, _ = var_from_nested_derivative(eq.lhs)
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
throw(ArgumentError("An ODESystem can only have one independent variable."))
diffvar in diffvars &&
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
!(symtype(diffvar) === Real || eltype(symtype(diffvar)) === Real) && throw(ArgumentError("Differential variable $diffvar has type $(symtype(diffvar)). Differential variables should not be concretely typed."))
push!(diffvars, diffvar)
end
push!(diffeq, eq)
else
push!(algeeq, eq)
end
end

diffvars, allunknowns, ps, Equation[diffeq; algeeq; compressed_eqs]
end
10 changes: 10 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,16 @@ end
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
end

@testset "Validate input types" begin
@parameters p d
@variables X(t)::Int64
eq = D(X) ~ p - d*X
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
@variables Y(t)[1:3]::String
eq = D(Y) ~ [p, p, p]
@test_throws ArgumentError @mtkbuild osys = ODESystem([eq], t)
end

# Test `isequal`
@testset "`isequal`" begin
@variables X(t)
Expand Down
13 changes: 12 additions & 1 deletion test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,17 @@ end
@test length(observed(sys)) == 1
end

# Test validating types of states
@testset "Validate input types" begin
@parameters p d
@variables X(t)::Int64
@brownian z
eq2 = D(X) ~ p - d*X + z
@test_throws ArgumentError @mtkbuild ssys = System([eq2], t)
noiseeq = [1]
@test_throws ArgumentError @named ssys = SDESystem([eq2], [noiseeq], t)
end

@testset "SDEFunctionExpr" begin
@parameters σ ρ β
@variables x(tt) y(tt) z(tt)
Expand Down Expand Up @@ -953,4 +964,4 @@ end
@test_throws ErrorException("SDESystem constructed by defining Brownian variables with @brownian must be simplified by calling `structural_simplify` before a SDEProblem can be constructed.") SDEProblem(de, u0map, (0.0, 100.0), parammap)
de = structural_simplify(de)
@test SDEProblem(de, u0map, (0.0, 100.0), parammap) isa SDEProblem
end
end
Loading