@@ -247,25 +247,6 @@ function recursive_unwrap(x::AbstractDict)
247
247
return anydict (unwrap (k) => recursive_unwrap (v) for (k, v) in x)
248
248
end
249
249
250
- """
251
- $(TYPEDSIGNATURES)
252
-
253
- Return the appropriate zero value for a symbolic variable representing a number or array of
254
- numbers. Sized array symbolics return a zero-filled array of matching size. Unsized array
255
- symbolics return an empty array of the appropriate `eltype`.
256
- """
257
- function zero_var (x:: Symbolic{T} ) where {V <: Number , T <: Union{V, AbstractArray{V}} }
258
- if Symbolics. isarraysymbolic (x)
259
- if is_sized_array_symbolic (x)
260
- return zeros (eltype (T), size (x))
261
- else
262
- return T[]
263
- end
264
- else
265
- return zero (T)
266
- end
267
- end
268
-
269
250
"""
270
251
$(TYPEDSIGNATURES)
271
252
@@ -362,7 +343,7 @@ Keyword arguments:
362
343
- `is_initializeprob, guesses`: Used to determine whether the system is missing guesses.
363
344
"""
364
345
function better_varmap_to_vars (varmap:: AbstractDict , vars:: Vector ;
365
- tofloat = true , container_type = Array,
346
+ tofloat = true , container_type = Array, floatT = Nothing,
366
347
toterm = default_toterm, promotetoconcrete = nothing , check = true ,
367
348
allow_symbolic = false , is_initializeprob = false )
368
349
isempty (vars) && return nothing
@@ -385,6 +366,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
385
366
is_initializeprob ? throw (MissingGuessError (missingsyms, missingvals)) :
386
367
throw (UnexpectedSymbolicValueInVarmap (missingsyms[1 ], missingvals[1 ]))
387
368
end
369
+ if tofloat && ! (floatT == Nothing)
370
+ vals = floatT .(vals)
371
+ end
388
372
end
389
373
390
374
if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
@@ -533,12 +517,12 @@ function (f::UpdateInitializeprob)(initializeprob, prob)
533
517
f. setvals (initializeprob, f. getvals (prob))
534
518
end
535
519
536
- function get_temporary_value (p)
520
+ function get_temporary_value (p, floatT = Float64 )
537
521
stype = symtype (unwrap (p))
538
522
return if stype == Real
539
- zero (Float64 )
523
+ zero (floatT )
540
524
elseif stype <: AbstractArray{Real}
541
- zeros (Float64 , size (p))
525
+ zeros (floatT , size (p))
542
526
elseif stype <: Real
543
527
zero (stype)
544
528
elseif stype <: AbstractArray
@@ -648,15 +632,32 @@ All other keyword arguments are forwarded to `InitializationProblem`.
648
632
"""
649
633
function maybe_build_initialization_problem (
650
634
sys:: AbstractSystem , op:: AbstractDict , u0map, pmap, t, defs,
651
- guesses, missing_unknowns; implicit_dae = false , u0_constructor = identity, kwargs... )
635
+ guesses, missing_unknowns; implicit_dae = false ,
636
+ u0_constructor = identity, floatT = Float64, kwargs... )
652
637
guesses = merge (ModelingToolkit. guesses (sys), todict (guesses))
653
638
654
639
if t === nothing && is_time_dependent (sys)
655
- t = 0.0
640
+ t = zero (floatT)
656
641
end
657
642
658
643
initializeprob = ModelingToolkit. InitializationProblem {true, SciMLBase.FullSpecialize} (
659
644
sys, t, u0map, pmap; guesses, kwargs... )
645
+ if state_values (initializeprob) != = nothing
646
+ initializeprob = remake (initializeprob; u0 = floatT .(state_values (initializeprob)))
647
+ end
648
+ initp = parameter_values (initializeprob)
649
+ if is_split (sys)
650
+ buffer, repack, _ = SciMLStructures. canonicalize (SciMLStructures. Tunable (), initp)
651
+ initp = repack (floatT .(buffer))
652
+ buffer, repack, _ = SciMLStructures. canonicalize (SciMLStructures. Initials (), initp)
653
+ initp = repack (floatT .(buffer))
654
+ elseif initp isa AbstractArray
655
+ initp′ = similar (initp, floatT)
656
+ copyto! (initp′, initp)
657
+ initp = initp′
658
+ end
659
+ initializeprob = remake (initializeprob; p = initp)
660
+
660
661
meta = get_metadata (initializeprob. f. sys)
661
662
662
663
if is_time_dependent (sys)
@@ -692,7 +693,7 @@ function maybe_build_initialization_problem(
692
693
get (op, p, missing ) === missing || continue
693
694
p = unwrap (p)
694
695
stype = symtype (p)
695
- op[p] = get_temporary_value (p)
696
+ op[p] = get_temporary_value (p, floatT )
696
697
if iscall (p) && operation (p) === getindex
697
698
arrp = arguments (p)[1 ]
698
699
op[arrp] = collect (arrp)
@@ -701,7 +702,7 @@ function maybe_build_initialization_problem(
701
702
702
703
if is_time_dependent (sys)
703
704
for v in missing_unknowns
704
- op[v] = zero_var (v )
705
+ op[v] = get_temporary_value (v, floatT )
705
706
end
706
707
empty! (missing_unknowns)
707
708
end
@@ -712,6 +713,26 @@ function maybe_build_initialization_problem(
712
713
initializeprobpmap))
713
714
end
714
715
716
+ """
717
+ $(TYPEDSIGNATURES)
718
+
719
+ Calculate the floating point type to use from the given `varmap` by looking at variables
720
+ with a constant value.
721
+ """
722
+ function float_type_from_varmap (varmap, floatT = Bool)
723
+ for (k, v) in varmap
724
+ symbolic_type (v) == NotSymbolic () || continue
725
+ is_array_of_symbolics (v) && continue
726
+
727
+ if v isa AbstractArray
728
+ floatT = promote_type (floatT, eltype (v))
729
+ elseif v isa Real
730
+ floatT = promote_type (floatT, typeof (v))
731
+ end
732
+ end
733
+ return float (floatT)
734
+ end
735
+
715
736
"""
716
737
$(TYPEDSIGNATURES)
717
738
@@ -815,12 +836,19 @@ function process_SciMLProblem(
815
836
op, missing_unknowns, missing_pars = build_operating_point! (sys,
816
837
u0map, pmap, defs, cmap, dvs, ps)
817
838
839
+ floatT = Bool
840
+ if u0Type <: AbstractArray && eltype (u0Type) <: Real
841
+ floatT = float (eltype (u0Type))
842
+ else
843
+ floatT = float_type_from_varmap (op, floatT)
844
+ end
845
+
818
846
if ! is_time_dependent (sys) || is_initializesystem (sys)
819
847
add_observed_equations! (u0map, obs)
820
848
end
821
849
if u0_constructor === identity && u0Type <: StaticArray
822
850
u0_constructor = vals -> SymbolicUtils. Code. create_array (
823
- u0Type, eltype (vals) , Val (1 ), Val (length (vals)), vals... )
851
+ u0Type, floatT , Val (1 ), Val (length (vals)), vals... )
824
852
end
825
853
if build_initializeprob
826
854
kws = maybe_build_initialization_problem (
@@ -830,7 +858,7 @@ function process_SciMLProblem(
830
858
warn_cyclic_dependency, check_units = check_initialization_units,
831
859
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
832
860
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
833
- u0_constructor)
861
+ u0_constructor, floatT )
834
862
835
863
kwargs = merge (kwargs, kws)
836
864
end
@@ -858,7 +886,7 @@ function process_SciMLProblem(
858
886
evaluate_varmap! (op, dvs; limit = substitution_limit)
859
887
860
888
u0 = better_varmap_to_vars (
861
- op, dvs; tofloat,
889
+ op, dvs; tofloat, floatT,
862
890
container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob)
863
891
864
892
if u0 != = nothing
@@ -882,7 +910,7 @@ function process_SciMLProblem(
882
910
end
883
911
evaluate_varmap! (op, ps; limit = substitution_limit)
884
912
if is_split (sys)
885
- p = MTKParameters (sys, op)
913
+ p = MTKParameters (sys, op; floatT = floatT )
886
914
else
887
915
p = better_varmap_to_vars (op, ps; tofloat, container_type = pType)
888
916
end
@@ -898,6 +926,16 @@ function process_SciMLProblem(
898
926
du0 = nothing
899
927
end
900
928
929
+ if build_initializeprob
930
+ t0 = t
931
+ if is_time_dependent (sys) && t0 === nothing
932
+ t0 = zero (floatT)
933
+ end
934
+ initialization_data = SciMLBase. remake_initialization_data (
935
+ kwargs. initialization_data, kwargs, u0, t0, p, u0, p)
936
+ kwargs = merge (kwargs,)
937
+ end
938
+
901
939
f = constructor (sys, dvs, ps, u0; p = p,
902
940
eval_expression = eval_expression,
903
941
eval_module = eval_module,
0 commit comments