Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix initialization in linearization_function #3348

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using Compat
using AbstractTrees
using DiffEqBase, SciMLBase, ForwardDiff
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain,
PeriodicClock, Clock, SolverStepClock, Continuous
PeriodicClock, Clock, SolverStepClock, Continuous, OverrideInit, NoInit
using Distributed
import JuliaFormatter
using MLStyle
Expand Down
141 changes: 60 additions & 81 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2377,6 +2377,9 @@ See also [`linearize`](@ref) which provides a higher-level interface.
function linearization_function(sys::AbstractSystem, inputs,
outputs; simplify = false,
initialize = true,
initializealg = nothing,
initialization_abstol = 1e-6,
initialization_reltol = 1e-3,
op = Dict(),
p = DiffEqBase.NullParameters(),
zero_dummy_der = false,
Expand All @@ -2403,88 +2406,29 @@ function linearization_function(sys::AbstractSystem, inputs,
op = merge(defs, op)
end
sys = ssys
u0map = Dict(k => v for (k, v) in op if is_variable(ssys, k))
initsys = structural_simplify(
generate_initializesystem(
sys, u0map = u0map, guesses = guesses(sys), algebraic_only = true),
fully_determined = false)

# HACK: some unknowns may not be involved in any initialization equations, and are
# thus removed from the system during `structural_simplify`.
# This causes `getu(initsys, unknowns(sys))` to fail, so we add them back as parameters
# for now.
missing_unknowns = setdiff(unknowns(sys), all_symbols(initsys))
if !isempty(missing_unknowns)
if warn_initialize_determined
@warn "Initialization system is underdetermined. No equations for $(missing_unknowns). Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
end
new_parameters = [parameters(initsys); missing_unknowns]
@set! initsys.ps = new_parameters
initsys = complete(initsys)
end

if p isa SciMLBase.NullParameters
p = Dict()
else
p = todict(p)
end
x0 = merge(defaults_and_guesses(sys), op)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
sys_ps = MTKParameters(sys, p, x0)
else
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
end
p[get_iv(sys)] = NaN
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
initsys_ps = parameters(initsys)
p_getter = build_explicit_observed_function(
sys, initsys_ps; eval_expression, eval_module)

u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)
get_initprob_u_p = let p_getter = p_getter,
p_setter! = setp(initsys, initsys_ps),
u_getter = u_getter

function (u, p, t)
p_setter!(oldps, p_getter(u, p, t))
newu = u_getter(u, p, t)
return newu, oldps
end
end
else
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
u_getter = build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)

function (u, p, t)
state = ProblemState(; u, p, t)
return u_getter(
state_values(state), parameter_values(state), current_time(state)),
p_getter(state)
end
end

if initializealg === nothing
initializealg = initialize ? OverrideInit() : NoInit()
end
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)

fun, u0, p = process_SciMLProblem(ODEFunction{true, SciMLBase.FullSpecialize}, sys, op, p; t = 0.0, build_initializeprob = initializealg isa OverrideInit, allow_incomplete = true, algebraic_only = true)
prob = ODEProblem(fun, u0, (nothing, nothing), p)

ps = parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = unknowns(sys),
get_initprob_u_p = get_initprob_u_p,
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
sys, unknowns(sys), ps; eval_expression, eval_module),
initfn = initfn,
initprobmap = initprobmap,
fun = fun,
prob = prob,
sys_ps = p,
h = h,
integ_cache = (similar(u0)),
chunk = ForwardDiff.Chunk(input_idxs),
sys_ps = sys_ps,
initialize = initialize,
initializealg = initializealg,
initialization_abstol = initialization_abstol,
initialization_reltol = initialization_reltol,
initialization_solver_alg = initialization_solver_alg,
sys = sys

Expand All @@ -2504,14 +2448,11 @@ function linearization_function(sys::AbstractSystem, inputs,
if u !== nothing # Handle systems without unknowns
length(sts) == length(u) ||
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
residual = fun(u, p, t)
if norm(residual[alge_idxs]) > √(eps(eltype(residual)))
initu0, initp = get_initprob_u_p(u, p, t)
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
nlsol = solve(initprob, initialization_solver_alg)
u = initprobmap(state_values(nlsol), parameter_values(nlsol))
end

integ = MockIntegrator{true}(u, p, t, integ_cache)
u, p, success = SciMLBase.get_initial_values(prob, integ, fun, initializealg, Val(true); abstol = initialization_abstol, reltol = initialization_reltol, nlsolve_alg = initialization_solver_alg)
if !success
error("Initialization algorithm $(initializealg) failed with `u = $u` and `p = $p`.")
end
uf = SciMLBase.UJacobianWrapper(fun, t, p)
fg_xz = ForwardDiff.jacobian(uf, u)
Expand Down Expand Up @@ -2546,6 +2487,44 @@ function linearization_function(sys::AbstractSystem, inputs,
return lin_fun, sys
end

"""
$(TYPEDEF)

Mock `DEIntegrator` to allow using `CheckInit` without having to create a new integrator
(and consequently depend on `OrdinaryDiffEq`).

# Fields

$(TYPEDFIELDS)
"""
struct MockIntegrator{iip, U, P, T, C} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
"""
The state vector.
"""
u::U
"""
The parameter object.
"""
p::P
"""
The current time.
"""
t::T
"""
The integrator cache.
"""
cache::C
end

function MockIntegrator{iip}(u::U, p::P, t::T, cache::C) where {iip, U, P, T, C}
return MockIntegrator{iip, U, P, T, C}(u, p, t, cache)
end

SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u
SymbolicIndexingInterface.parameter_values(integ::MockIntegrator) = integ.p
SymbolicIndexingInterface.current_time(integ::MockIntegrator) = integ.t
SciMLBase.get_tmp_cache(integ::MockIntegrator) = integ.cache

"""
(; A, B, C, D), simplified_sys = linearize_symbolic(sys::AbstractSystem, inputs, outputs; simplify = false, allow_input_derivatives = false, kwargs...)

Expand Down
2 changes: 1 addition & 1 deletion src/systems/analysis_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# get the anlysis point

Check warning on line 513 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"anlysis" should be "analysis".
ap_sys_eqs = copy(get_eqs(ap_sys))
ap = ap_sys_eqs[ap_idx].rhs

Expand Down Expand Up @@ -564,7 +564,7 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# modified quations

Check warning on line 567 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"quations" should be "equations".
ap_sys_eqs = copy(get_eqs(ap_sys))
@set! ap_sys.eqs = ap_sys_eqs
ap = ap_sys_eqs[ap_idx].rhs
Expand Down Expand Up @@ -863,7 +863,7 @@
sys, ap, args...; loop_openings = [], system_modifier = identity, kwargs...)
lin_fun, ssys = $(utility_fun)(
sys, ap, args...; loop_openings, system_modifier, kwargs...)
ModelingToolkit.linearize(ssys, lin_fun; kwargs...), ssys
ModelingToolkit.linearize(ssys, lin_fun), ssys
end
end

Expand All @@ -876,7 +876,7 @@
# Keyword Arguments

- `system_modifier`: a function which takes the modified system and returns a new system
with any required further modifications peformed.

Check warning on line 879 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"peformed" should be "performed".
"""
function open_loop(sys, ap::Union{Symbol, AnalysisPoint}; system_modifier = identity)
ap = only(canonicalize_ap(sys, ap))
Expand Down
5 changes: 3 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
check_units = true,
use_scc = true,
allow_incomplete = false,
algebraic_only = false,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
Expand All @@ -1305,12 +1306,12 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
isys = structural_simplify(
generate_initializesystem(
sys; initialization_eqs, check_units, pmap = parammap,
guesses, extra_metadata = (; use_scc)); fully_determined)
guesses, extra_metadata = (; use_scc), algebraic_only); fully_determined)
else
isys = structural_simplify(
generate_initializesystem(
sys; u0map, initialization_eqs, check_units,
pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined)
pmap = parammap, guesses, extra_metadata = (; use_scc)), algebraic_only; fully_determined)
end

meta = get_metadata(isys)
Expand Down
11 changes: 9 additions & 2 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ Keyword arguments:
other to attempt to arrive at a numeric value.
- `use_scc`: Whether to use `SCCNonlinearProblem` for initialization if the system is fully
determined.
- `algebraic_only`: Whether to build the initialization problem using only algebraic equations.
- `allow_incomplete`: Whether to allow incomplete initialization problems.

All other keyword arguments are passed as-is to `constructor`.
"""
Expand All @@ -674,7 +676,8 @@ function process_SciMLProblem(
symbolic_u0 = false, warn_cyclic_dependency = false,
circular_dependency_max_cycle_length = length(all_symbols(sys)),
circular_dependency_max_cycles = 10,
substitution_limit = 100, use_scc = true, kwargs...)
substitution_limit = 100, use_scc = true, algebraic_only = false,
allow_incomplete = false, kwargs...)
dvs = unknowns(sys)
ps = parameters(sys)
iv = has_iv(sys) ? get_iv(sys) : nothing
Expand All @@ -696,14 +699,18 @@ function process_SciMLProblem(

op, missing_unknowns, missing_pars = build_operating_point(
u0map, pmap, defs, cmap, dvs, ps)
filter_missing_values!(op)
filter_missing_values!(u0map)
filter_missing_values!(pmap)

if build_initializeprob
kws = maybe_build_initialization_problem(
sys, op, u0map, pmap, t, defs, guesses, missing_unknowns;
implicit_dae, warn_initialize_determined, initialization_eqs,
eval_expression, eval_module, fully_determined,
warn_cyclic_dependency, check_units = check_initialization_units,
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc)
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
algebraic_only, allow_incomplete)

kwargs = merge(kwargs, kws)
end
Expand Down
Loading