Skip to content

Commit e574fb3

Browse files
Merge pull request #3561 from AayushSabharwal/as/initial-promote
fix: fix values being promoted to `Float64` in problem construction
2 parents 19f8164 + cd19c7e commit e574fb3

File tree

6 files changed

+99
-65
lines changed

6 files changed

+99
-65
lines changed

src/systems/abstractsystem.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -733,11 +733,11 @@ function add_initialization_parameters(sys::AbstractSystem)
733733
defs = copy(get_defaults(sys))
734734
for ivar in initials
735735
if symbolic_type(ivar) == ScalarSymbolic()
736-
defs[ivar] = zero_var(ivar)
736+
defs[ivar] = false
737737
else
738738
defs[ivar] = collect(ivar)
739739
for scal_ivar in defs[ivar]
740-
defs[scal_ivar] = zero_var(scal_ivar)
740+
defs[scal_ivar] = false
741741
end
742742
end
743743
end

src/systems/diffeqs/abstractodesystem.jl

-26
Original file line numberDiff line numberDiff line change
@@ -1541,32 +1541,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
15411541
filter_missing_values!(parammap)
15421542
u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), u0map)
15431543

1544-
fullmap = merge(u0map, parammap)
1545-
u0T = Union{}
1546-
for sym in unknowns(isys)
1547-
val = fixpoint_sub(sym, fullmap)
1548-
symbolic_type(val) == NotSymbolic() || continue
1549-
u0T = promote_type(u0T, typeof(val))
1550-
end
1551-
for eq in observed(isys)
1552-
# ignore HACK-ed observed equations
1553-
symbolic_type(eq.lhs) == ArraySymbolic() && continue
1554-
val = fixpoint_sub(eq.lhs, fullmap)
1555-
symbolic_type(val) == NotSymbolic() || continue
1556-
u0T = promote_type(u0T, typeof(val))
1557-
end
1558-
if u0T != Union{}
1559-
u0T = eltype(u0T)
1560-
u0map = Dict(k => if v === nothing
1561-
nothing
1562-
elseif symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)
1563-
v isa AbstractArray ? u0T.(v) : u0T(v)
1564-
else
1565-
v
1566-
end
1567-
for (k, v) in u0map)
1568-
end
1569-
15701544
TProb = if neqs == nunknown && isempty(unassigned_vars)
15711545
if use_scc && neqs > 0
15721546
if is_split(isys)

src/systems/nonlinear/initializesystem.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem;
88
pmap = Dict(),
99
initialization_eqs = [],
1010
guesses = Dict(),
11-
default_dd_guess = 0.0,
11+
default_dd_guess = Bool(0),
1212
algebraic_only = false,
1313
check_units = true, check_defguess = false,
1414
name = nameof(sys), extra_metadata = (;), kwargs...)
@@ -646,10 +646,12 @@ function SciMLBase.remake_initialization_data(
646646

647647
op, missing_unknowns, missing_pars = build_operating_point!(sys,
648648
u0map, pmap, defs, cmap, dvs, ps)
649+
floatT = float_type_from_varmap(op)
649650
kws = maybe_build_initialization_problem(
650651
sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns;
651-
use_scc, initialization_eqs, allow_incomplete = true)
652-
return get(kws, :initialization_data, nothing)
652+
use_scc, initialization_eqs, floatT, allow_incomplete = true)
653+
654+
return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)
653655
end
654656

655657
function SciMLBase.late_binding_update_u0_p(

src/systems/parameter_buffer.jl

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x)
2-
symconvert(::Type{T}, x) where {T} = convert(T, x)
3-
symconvert(::Type{Real}, x::Integer) = convert(Float64, x)
2+
symconvert(::Type{T}, x::V) where {T, V} = convert(promote_type(T, V), x)
3+
symconvert(::Type{Real}, x::Integer) = convert(Float16, x)
44
symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x))
55

66
struct MTKParameters{T, I, D, C, N, H}
@@ -28,7 +28,7 @@ the default behavior).
2828
"""
2929
function MTKParameters(
3030
sys::AbstractSystem, p, u0 = Dict(); tofloat = false,
31-
t0 = nothing, substitution_limit = 1000)
31+
t0 = nothing, substitution_limit = 1000, floatT = nothing)
3232
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
3333
get_index_cache(sys)
3434
else
@@ -56,6 +56,10 @@ function MTKParameters(
5656
op[get_iv(sys)] = t0
5757
end
5858

59+
if floatT === nothing
60+
floatT = float(float_type_from_varmap(op))
61+
end
62+
5963
isempty(missing_pars) || throw(MissingParametersError(collect(missing_pars)))
6064
evaluate_varmap!(op, ps; limit = substitution_limit)
6165

@@ -111,6 +115,9 @@ function MTKParameters(
111115
if ctype <: FnType
112116
ctype = fntype_to_function_type(ctype)
113117
end
118+
if ctype == Real && floatT !== nothing
119+
ctype = floatT
120+
end
114121
val = symconvert(ctype, val)
115122
done = set_value(sym, val)
116123
if !done && Symbolics.isarraysymbolic(sym)

src/systems/problem_utils.jl

+69-31
Original file line numberDiff line numberDiff line change
@@ -247,25 +247,6 @@ function recursive_unwrap(x::AbstractDict)
247247
return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x)
248248
end
249249

250-
"""
251-
$(TYPEDSIGNATURES)
252-
253-
Return the appropriate zero value for a symbolic variable representing a number or array of
254-
numbers. Sized array symbolics return a zero-filled array of matching size. Unsized array
255-
symbolics return an empty array of the appropriate `eltype`.
256-
"""
257-
function zero_var(x::Symbolic{T}) where {V <: Number, T <: Union{V, AbstractArray{V}}}
258-
if Symbolics.isarraysymbolic(x)
259-
if is_sized_array_symbolic(x)
260-
return zeros(eltype(T), size(x))
261-
else
262-
return T[]
263-
end
264-
else
265-
return zero(T)
266-
end
267-
end
268-
269250
"""
270251
$(TYPEDSIGNATURES)
271252
@@ -362,7 +343,7 @@ Keyword arguments:
362343
- `is_initializeprob, guesses`: Used to determine whether the system is missing guesses.
363344
"""
364345
function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
365-
tofloat = true, container_type = Array,
346+
tofloat = true, container_type = Array, floatT = Nothing,
366347
toterm = default_toterm, promotetoconcrete = nothing, check = true,
367348
allow_symbolic = false, is_initializeprob = false)
368349
isempty(vars) && return nothing
@@ -385,6 +366,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
385366
is_initializeprob ? throw(MissingGuessError(missingsyms, missingvals)) :
386367
throw(UnexpectedSymbolicValueInVarmap(missingsyms[1], missingvals[1]))
387368
end
369+
if tofloat && !(floatT == Nothing)
370+
vals = floatT.(vals)
371+
end
388372
end
389373

390374
if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
@@ -533,12 +517,12 @@ function (f::UpdateInitializeprob)(initializeprob, prob)
533517
f.setvals(initializeprob, f.getvals(prob))
534518
end
535519

536-
function get_temporary_value(p)
520+
function get_temporary_value(p, floatT = Float64)
537521
stype = symtype(unwrap(p))
538522
return if stype == Real
539-
zero(Float64)
523+
zero(floatT)
540524
elseif stype <: AbstractArray{Real}
541-
zeros(Float64, size(p))
525+
zeros(floatT, size(p))
542526
elseif stype <: Real
543527
zero(stype)
544528
elseif stype <: AbstractArray
@@ -648,15 +632,32 @@ All other keyword arguments are forwarded to `InitializationProblem`.
648632
"""
649633
function maybe_build_initialization_problem(
650634
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
651-
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...)
635+
guesses, missing_unknowns; implicit_dae = false,
636+
u0_constructor = identity, floatT = Float64, kwargs...)
652637
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
653638

654639
if t === nothing && is_time_dependent(sys)
655-
t = 0.0
640+
t = zero(floatT)
656641
end
657642

658643
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
659644
sys, t, u0map, pmap; guesses, kwargs...)
645+
if state_values(initializeprob) !== nothing
646+
initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob)))
647+
end
648+
initp = parameter_values(initializeprob)
649+
if is_split(sys)
650+
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), initp)
651+
initp = repack(floatT.(buffer))
652+
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp)
653+
initp = repack(floatT.(buffer))
654+
elseif initp isa AbstractArray
655+
initp′ = similar(initp, floatT)
656+
copyto!(initp′, initp)
657+
initp = initp′
658+
end
659+
initializeprob = remake(initializeprob; p = initp)
660+
660661
meta = get_metadata(initializeprob.f.sys)
661662

662663
if is_time_dependent(sys)
@@ -692,7 +693,7 @@ function maybe_build_initialization_problem(
692693
get(op, p, missing) === missing || continue
693694
p = unwrap(p)
694695
stype = symtype(p)
695-
op[p] = get_temporary_value(p)
696+
op[p] = get_temporary_value(p, floatT)
696697
if iscall(p) && operation(p) === getindex
697698
arrp = arguments(p)[1]
698699
op[arrp] = collect(arrp)
@@ -701,7 +702,7 @@ function maybe_build_initialization_problem(
701702

702703
if is_time_dependent(sys)
703704
for v in missing_unknowns
704-
op[v] = zero_var(v)
705+
op[v] = get_temporary_value(v, floatT)
705706
end
706707
empty!(missing_unknowns)
707708
end
@@ -712,6 +713,26 @@ function maybe_build_initialization_problem(
712713
initializeprobpmap))
713714
end
714715

716+
"""
717+
$(TYPEDSIGNATURES)
718+
719+
Calculate the floating point type to use from the given `varmap` by looking at variables
720+
with a constant value.
721+
"""
722+
function float_type_from_varmap(varmap, floatT = Bool)
723+
for (k, v) in varmap
724+
symbolic_type(v) == NotSymbolic() || continue
725+
is_array_of_symbolics(v) && continue
726+
727+
if v isa AbstractArray
728+
floatT = promote_type(floatT, eltype(v))
729+
elseif v isa Real
730+
floatT = promote_type(floatT, typeof(v))
731+
end
732+
end
733+
return float(floatT)
734+
end
735+
715736
"""
716737
$(TYPEDSIGNATURES)
717738
@@ -815,12 +836,19 @@ function process_SciMLProblem(
815836
op, missing_unknowns, missing_pars = build_operating_point!(sys,
816837
u0map, pmap, defs, cmap, dvs, ps)
817838

839+
floatT = Bool
840+
if u0Type <: AbstractArray && eltype(u0Type) <: Real
841+
floatT = float(eltype(u0Type))
842+
else
843+
floatT = float_type_from_varmap(op, floatT)
844+
end
845+
818846
if !is_time_dependent(sys) || is_initializesystem(sys)
819847
add_observed_equations!(u0map, obs)
820848
end
821849
if u0_constructor === identity && u0Type <: StaticArray
822850
u0_constructor = vals -> SymbolicUtils.Code.create_array(
823-
u0Type, eltype(vals), Val(1), Val(length(vals)), vals...)
851+
u0Type, floatT, Val(1), Val(length(vals)), vals...)
824852
end
825853
if build_initializeprob
826854
kws = maybe_build_initialization_problem(
@@ -830,7 +858,7 @@ function process_SciMLProblem(
830858
warn_cyclic_dependency, check_units = check_initialization_units,
831859
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
832860
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
833-
u0_constructor)
861+
u0_constructor, floatT)
834862

835863
kwargs = merge(kwargs, kws)
836864
end
@@ -858,7 +886,7 @@ function process_SciMLProblem(
858886
evaluate_varmap!(op, dvs; limit = substitution_limit)
859887

860888
u0 = better_varmap_to_vars(
861-
op, dvs; tofloat,
889+
op, dvs; tofloat, floatT,
862890
container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob)
863891

864892
if u0 !== nothing
@@ -882,7 +910,7 @@ function process_SciMLProblem(
882910
end
883911
evaluate_varmap!(op, ps; limit = substitution_limit)
884912
if is_split(sys)
885-
p = MTKParameters(sys, op)
913+
p = MTKParameters(sys, op; floatT = floatT)
886914
else
887915
p = better_varmap_to_vars(op, ps; tofloat, container_type = pType)
888916
end
@@ -898,6 +926,16 @@ function process_SciMLProblem(
898926
du0 = nothing
899927
end
900928

929+
if build_initializeprob
930+
t0 = t
931+
if is_time_dependent(sys) && t0 === nothing
932+
t0 = zero(floatT)
933+
end
934+
initialization_data = SciMLBase.remake_initialization_data(
935+
kwargs.initialization_data, kwargs, u0, t0, p, u0, p)
936+
kwargs = merge(kwargs,)
937+
end
938+
901939
f = constructor(sys, dvs, ps, u0; p = p,
902940
eval_expression = eval_expression,
903941
eval_module = eval_module,

test/initial_values.jl

+13
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,16 @@ end
252252
ps = [p => [4.0, 5.0]]
253253
@test_nowarn NonlinearProblem(nlsys, u0, ps)
254254
end
255+
256+
@testset "Issue#3553: Retain `Float32` initial values" begin
257+
@parameters p d
258+
@variables X(t)
259+
eqs = [D(X) ~ p - d * X]
260+
@mtkbuild osys = ODESystem(eqs, t)
261+
u0 = [X => 1.0f0]
262+
ps = [p => 1.0f0, d => 2.0f0]
263+
oprob = ODEProblem(osys, u0, (0.0f0, 1.0f0), ps)
264+
sol = solve(oprob)
265+
@test eltype(oprob.u0) == Float32
266+
@test eltype(eltype(sol.u)) == Float32
267+
end

0 commit comments

Comments
 (0)