Skip to content

Commit 6d91e32

Browse files
Merge pull request #926 from AayushSabharwal/as/late-binding-remake
feat: add `late_binding_update_u0_p`
2 parents 575392e + 34268b5 commit 6d91e32

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/remake.jl

+28-1
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,34 @@ function updated_u0_p(
10751075
return (u0 === missing ? state_values(prob) : u0),
10761076
(p === missing ? parameter_values(prob) : p)
10771077
end
1078-
return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
1078+
newu0, newp = _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
1079+
return late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
1080+
end
1081+
1082+
"""
1083+
$(TYPEDSIGNATURES)
1084+
1085+
A function to perform custom modifications to `newu0` and/or `newp` after they have been
1086+
constructed in `remake`. `root_indp` is the innermost index provider found by recursively
1087+
calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. Returns
1088+
the updated `newu0` and `newp`.
1089+
"""
1090+
function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
1091+
return newu0, newp
1092+
end
1093+
1094+
"""
1095+
$(TYPEDSIGNATURES)
1096+
1097+
Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after finding
1098+
`root_indp`.
1099+
"""
1100+
function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
1101+
root_indp = prob
1102+
while hasmethod(symbolic_container, Tuple{typeof(root_indp)}) && (sc = symbolic_container(root_indp)) !== root_indp
1103+
root_indp = sc
1104+
end
1105+
return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
10791106
end
10801107

10811108
# overloaded in MTK to intercept symbolic remake

test/remake_tests.jl

+12
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ for T in containerTypes
8383
push!(probs, NonlinearLeastSquaresProblem(fn, u0, T(p)))
8484
end
8585

86+
# temporary definition to test this functionality
87+
function SciMLBase.late_binding_update_u0_p(prob, u0, p::SciMLBase.NullParameters, t0, newu0, newp)
88+
return newu0, ones(3)
89+
end
90+
8691
for prob in deepcopy(probs)
8792
prob2 = @inferred remake(prob)
8893
@test prob2.u0 == u0
@@ -274,8 +279,15 @@ for prob in deepcopy(probs)
274279
end
275280
ForwardDiff.derivative(fakeloss!, 1.0)
276281
end
282+
283+
# test late_binding_update_u0_p
284+
prob2 = remake(prob; p = SciMLBase.NullParameters())
285+
@test prob2.p ones(3)
277286
end
278287

288+
# delete the method defined here to prevent breaking other tests
289+
Base.delete_method(only(methods(SciMLBase.late_binding_update_u0_p, @__MODULE__)))
290+
279291
# eltype(()) <: Pair, so ensure that this doesn't error
280292
function lorenz!(du, u, _, t)
281293
du[1] = 1 * (u[2] - u[1])

0 commit comments

Comments
 (0)