@@ -654,17 +654,40 @@ end
654
654
655
655
function SciMLBase. late_binding_update_u0_p (
656
656
prob, sys:: AbstractSystem , u0, p, t0, newu0, newp)
657
+ supports_initialization (sys) || return newu0, newp
657
658
u0 === missing && return newu0, (p === missing ? copy (newp) : newp)
658
- eltype (u0) <: Pair || return newu0, (p === missing ? copy (newp) : newp)
659
+ # non-symbolic u0 updates initials...
660
+ if ! (eltype (u0) <: Pair )
661
+ # if `p` is not provided or is symbolic
662
+ p === missing || eltype (p) <: Pair || return newu0, newp
663
+ newu0 === nothing && return newu0, newp
664
+ all (is_parameter (sys, Initial (x)) for x in unknowns (sys)) || return newu0, newp
665
+ newp = p === missing ? copy (newp) : newp
666
+ initials, repack, alias = SciMLStructures. canonicalize (
667
+ SciMLStructures. Initials (), newp)
668
+ if eltype (initials) != eltype (newu0)
669
+ initials = DiffEqBase. promote_u0 (initials, newu0, t0)
670
+ newp = repack (initials)
671
+ end
672
+ if length (newu0) != length (unknowns (sys))
673
+ throw (ArgumentError (" Expected `newu0` to be of same length as unknowns ($(length (unknowns (sys))) ). Got $(typeof (newu0)) of length $(length (newu0)) " ))
674
+ end
675
+ setp (sys, Initial .(unknowns (sys)))(newp, newu0)
676
+ return newu0, newp
677
+ end
659
678
660
679
newp = p === missing ? copy (newp) : newp
661
680
newu0 = DiffEqBase. promote_u0 (newu0, newp, t0)
662
681
tunables, repack, alias = SciMLStructures. canonicalize (SciMLStructures. Tunable (), newp)
663
- tunables = DiffEqBase. promote_u0 (tunables, newu0, t0)
664
- newp = repack (tunables)
682
+ if eltype (tunables) != eltype (newu0)
683
+ tunables = DiffEqBase. promote_u0 (tunables, newu0, t0)
684
+ newp = repack (tunables)
685
+ end
665
686
initials, repack, alias = SciMLStructures. canonicalize (SciMLStructures. Initials (), newp)
666
- initials = DiffEqBase. promote_u0 (initials, newu0, t0)
667
- newp = repack (initials)
687
+ if eltype (initials) != eltype (newu0)
688
+ initials = DiffEqBase. promote_u0 (initials, newu0, t0)
689
+ newp = repack (initials)
690
+ end
668
691
669
692
allsyms = all_symbols (sys)
670
693
for (k, v) in u0
0 commit comments