Skip to content

fix: update initials with non-symbolic u0 in remake #3530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 3, 2025
30 changes: 25 additions & 5 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,17 +654,37 @@ end

function SciMLBase.late_binding_update_u0_p(
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
supports_initialization(sys) || return newu0, newp
u0 === missing && return newu0, (p === missing ? copy(newp) : newp)
eltype(u0) <: Pair || return newu0, (p === missing ? copy(newp) : newp)
# non-symbolic u0 updates initials...
if !(eltype(u0) <: Pair)
# if `p` is not provided or is symbolic
p === missing || eltype(p) <: Pair || return newu0, newp
newu0 === nothing && return newu0, newp
all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp
newp = p === missing ? copy(newp) : newp
initials, repack, alias = SciMLStructures.canonicalize(
SciMLStructures.Initials(), newp)
if eltype(initials) != eltype(newu0)
initials = DiffEqBase.promote_u0(initials, newu0, t0)
newp = repack(initials)
end
setp(sys, Initial.(unknowns(sys)))(newp, newu0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this validate that the size is correct? It seems like that should get a contextualized error message.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'll throw a BroadcastError just because of how SII works but year, an error message is a good idea.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll let tests run. If they pass, I'll push the error message and merge.

return newu0, newp
end

newp = p === missing ? copy(newp) : newp
newu0 = DiffEqBase.promote_u0(newu0, newp, t0)
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
newp = repack(tunables)
if eltype(tunables) != eltype(newu0)
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
newp = repack(tunables)
end
initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
initials = DiffEqBase.promote_u0(initials, newu0, t0)
newp = repack(initials)
if eltype(initials) != eltype(newu0)
initials = DiffEqBase.promote_u0(initials, newu0, t0)
newp = repack(initials)
end

allsyms = all_symbols(sys)
for (k, v) in u0
Expand Down
15 changes: 15 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1476,3 +1476,18 @@ end
@test sol.ps[Γ[1]] ≈ 5.0
end
end

@testset "Issue#3504: Update initials when `remake` called with non-symbolic `u0`" begin
@variables x(t) y(t)
@parameters c1 c2
@mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t)
prob1 = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0])
prob2 = remake(prob1, u0 = [2.0, 3.0])
prob3 = remake(prob1, u0 = [2.0, 3.0], p = [c1 => 2.0])
integ1 = init(prob1, Tsit5())
integ2 = init(prob2, Tsit5())
integ3 = init(prob3, Tsit5())
@test integ2.u ≈ [2.0, 3.0]
@test integ3.u ≈ [2.0, 3.0]
@test integ3.ps[c1] ≈ 2.0
end
2 changes: 1 addition & 1 deletion test/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ testdict = Dict([:test => 1])

prob_ = remake(prob, u0 = [1.0, 2.0, 3.0], p = [a => 1.1, b => 1.2, c => 1.3])
@test prob_.u0 == [1.0, 2.0, 3.0]
initials = unknowns(sys) .=> ones(3)
initials = unknowns(sys) .=> [1.0, 2.0, 3.0]
@test prob_.p == MTKParameters(sys, [a => 1.1, b => 1.2, c => 1.3, initials...])

prob_ = remake(prob, u0 = Dict(y => 2.0), p = Dict(a => 2.0))
Expand Down
Loading