Skip to content

Commit 3a54d9e

Browse files
Merge pull request #3350 from vyudu/validate_parammap
fix: validate u0map and pmap
2 parents 3762673 + 18e766e commit 3a54d9e

File tree

3 files changed

+76
-3
lines changed

3 files changed

+76
-3
lines changed

src/systems/parameter_buffer.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ end
721721

722722
function Base.showerror(io::IO, e::MissingParametersError)
723723
println(io, MISSING_PARAMETERS_MESSAGE)
724-
println(io, e.vars)
724+
println(io, join(e.vars, ", "))
725725
end
726726

727727
function InvalidParameterSizeException(param, val)

src/systems/problem_utils.jl

+41-2
Original file line numberDiff line numberDiff line change
@@ -752,12 +752,14 @@ function process_SciMLProblem(
752752

753753
u0Type = typeof(u0map)
754754
pType = typeof(pmap)
755-
_u0map = u0map
755+
756756
u0map = to_varmap(u0map, dvs)
757757
symbols_to_symbolics!(sys, u0map)
758-
_pmap = pmap
759758
pmap = to_varmap(pmap, parameters(sys))
760759
symbols_to_symbolics!(sys, pmap)
760+
761+
check_inputmap_keys(sys, u0map, pmap)
762+
761763
defs = add_toterms(recursive_unwrap(defaults(sys)))
762764
cmap, cs = get_cmap(sys)
763765
kwargs = NamedTuple(kwargs)
@@ -854,6 +856,43 @@ function process_SciMLProblem(
854856
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
855857
end
856858

859+
# Check that the keys of a u0map or pmap are valid
860+
# (i.e. are symbolic keys, and are defined for the system.)
861+
function check_inputmap_keys(sys, u0map, pmap)
862+
badvarkeys = Any[]
863+
for k in keys(u0map)
864+
if symbolic_type(k) === NotSymbolic()
865+
push!(badvarkeys, k)
866+
end
867+
end
868+
869+
badparamkeys = Any[]
870+
for k in keys(pmap)
871+
if symbolic_type(k) === NotSymbolic()
872+
push!(badparamkeys, k)
873+
end
874+
end
875+
(isempty(badvarkeys) && isempty(badparamkeys)) || throw(InvalidKeyError(collect(badvarkeys), collect(badparamkeys)))
876+
end
877+
878+
const BAD_KEY_MESSAGE = """
879+
Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned.
880+
The following keys are invalid:
881+
"""
882+
883+
struct InvalidKeyError <: Exception
884+
vars::Any
885+
params::Any
886+
end
887+
888+
function Base.showerror(io::IO, e::InvalidKeyError)
889+
println(io, BAD_KEY_MESSAGE)
890+
println(io, "u0map: $(join(e.vars, ", "))")
891+
println(io, "pmap: $(join(e.params, ", "))")
892+
end
893+
894+
895+
857896
##############
858897
# Legacy functions for backward compatibility
859898
##############

test/problem_validation.jl

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
4+
@testset "Input map validation" begin
5+
import ModelingToolkit: InvalidKeyError, MissingParametersError
6+
@variables X(t)
7+
@parameters p d
8+
eqs = [D(X) ~ p - d*X]
9+
@mtkbuild osys = ODESystem(eqs, t)
10+
11+
p = "I accidentally renamed p"
12+
u0 = [X => 1.0]
13+
ps = [p => 1.0, d => 0.5]
14+
@test_throws MissingParametersError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
15+
16+
@parameters p d
17+
ps = [p => 1.0, d => 0.5, "Random stuff" => 3.0]
18+
@test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
19+
20+
u0 = [:X => 1.0, "random" => 3.0]
21+
@test_throws InvalidKeyError oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)
22+
23+
@variables x(t) y(t) z(t)
24+
@parameters a b c d
25+
eqs = [D(x) ~ x*a, D(y) ~ y*c, D(z) ~ b + d]
26+
@mtkbuild sys = ODESystem(eqs, t)
27+
pmap = [a => 1, b => 2, c => 3, d => 4, "b" => 2]
28+
u0map = [x => 1, y => 2, z => 3]
29+
@test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
30+
31+
pmap = [a => 1, b => 2, c => 3, d => 4]
32+
u0map = [x => 1, y => 2, z => 3, :0 => 3]
33+
@test_throws InvalidKeyError ODEProblem(sys, u0map, (0., 1.), pmap)
34+
end

0 commit comments

Comments
 (0)