Skip to content

Commit 3d7fd3b

Browse files
Merge pull request #3530 from AayushSabharwal/as/remake-copy-u0
fix: update initials with non-symbolic `u0` in `remake`
2 parents 606a043 + 01a43be commit 3d7fd3b

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

Diff for: src/systems/nonlinear/initializesystem.jl

+28-5
Original file line numberDiff line numberDiff line change
@@ -654,17 +654,40 @@ end
654654

655655
function SciMLBase.late_binding_update_u0_p(
656656
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
657+
supports_initialization(sys) || return newu0, newp
657658
u0 === missing && return newu0, (p === missing ? copy(newp) : newp)
658-
eltype(u0) <: Pair || return newu0, (p === missing ? copy(newp) : newp)
659+
# non-symbolic u0 updates initials...
660+
if !(eltype(u0) <: Pair)
661+
# if `p` is not provided or is symbolic
662+
p === missing || eltype(p) <: Pair || return newu0, newp
663+
newu0 === nothing && return newu0, newp
664+
all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp
665+
newp = p === missing ? copy(newp) : newp
666+
initials, repack, alias = SciMLStructures.canonicalize(
667+
SciMLStructures.Initials(), newp)
668+
if eltype(initials) != eltype(newu0)
669+
initials = DiffEqBase.promote_u0(initials, newu0, t0)
670+
newp = repack(initials)
671+
end
672+
if length(newu0) != length(unknowns(sys))
673+
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(unknowns(sys)))). Got $(typeof(newu0)) of length $(length(newu0))"))
674+
end
675+
setp(sys, Initial.(unknowns(sys)))(newp, newu0)
676+
return newu0, newp
677+
end
659678

660679
newp = p === missing ? copy(newp) : newp
661680
newu0 = DiffEqBase.promote_u0(newu0, newp, t0)
662681
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
663-
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
664-
newp = repack(tunables)
682+
if eltype(tunables) != eltype(newu0)
683+
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
684+
newp = repack(tunables)
685+
end
665686
initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
666-
initials = DiffEqBase.promote_u0(initials, newu0, t0)
667-
newp = repack(initials)
687+
if eltype(initials) != eltype(newu0)
688+
initials = DiffEqBase.promote_u0(initials, newu0, t0)
689+
newp = repack(initials)
690+
end
668691

669692
allsyms = all_symbols(sys)
670693
for (k, v) in u0

Diff for: test/initializationsystem.jl

+15
Original file line numberDiff line numberDiff line change
@@ -1476,3 +1476,18 @@ end
14761476
@test sol.ps[Γ[1]] 5.0
14771477
end
14781478
end
1479+
1480+
@testset "Issue#3504: Update initials when `remake` called with non-symbolic `u0`" begin
1481+
@variables x(t) y(t)
1482+
@parameters c1 c2
1483+
@mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t)
1484+
prob1 = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0])
1485+
prob2 = remake(prob1, u0 = [2.0, 3.0])
1486+
prob3 = remake(prob1, u0 = [2.0, 3.0], p = [c1 => 2.0])
1487+
integ1 = init(prob1, Tsit5())
1488+
integ2 = init(prob2, Tsit5())
1489+
integ3 = init(prob3, Tsit5())
1490+
@test integ2.u [2.0, 3.0]
1491+
@test integ3.u [2.0, 3.0]
1492+
@test integ3.ps[c1] 2.0
1493+
end

Diff for: test/nonlinearsystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ testdict = Dict([:test => 1])
239239

240240
prob_ = remake(prob, u0 = [1.0, 2.0, 3.0], p = [a => 1.1, b => 1.2, c => 1.3])
241241
@test prob_.u0 == [1.0, 2.0, 3.0]
242-
initials = unknowns(sys) .=> ones(3)
242+
initials = unknowns(sys) .=> [1.0, 2.0, 3.0]
243243
@test prob_.p == MTKParameters(sys, [a => 1.1, b => 1.2, c => 1.3, initials...])
244244

245245
prob_ = remake(prob, u0 = Dict(y => 2.0), p = Dict(a => 2.0))

0 commit comments

Comments
 (0)