Skip to content

Commit 884a0fb

Browse files
refactor: use build_function_wrapper in SCCNonlinearProblem
1 parent dfc49b4 commit 884a0fb

File tree

1 file changed

+11
-24
lines changed

1 file changed

+11
-24
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ struct CacheWriter{F}
563563
end
564564

565565
function (cw::CacheWriter)(p, sols)
566-
cw.fn(p.caches, sols, p...)
566+
cw.fn(p.caches, sols, p)
567567
end
568568

569569
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
@@ -572,22 +572,15 @@ function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
572572
ps = parameters(sys)
573573
rps = reorder_parameters(sys, ps)
574574
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
575-
cmap, cs = get_cmap(sys)
576-
cmap_assigns = [eq.lhs eq.rhs for eq in cmap]
577-
578-
outsyms = [Symbol(:out, i) for i in eachindex(buffer_types)]
579575
body = map(eachindex(buffer_types), buffer_types) do i, T
580576
Symbol(:tmp, i) SetArray(true, :(out[$i]), get(exprs, T, []))
581577
end
582-
fn = Func(
583-
[:out, DestructuredArgs(DestructuredArgs.(solsyms)),
584-
DestructuredArgs.(rps)...],
585-
[],
586-
Let(body, :())
587-
) |> wrap_assignments(false, obs_assigns)[2] |>
588-
wrap_parameter_dependencies(sys, false)[2] |>
589-
wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |>
590-
wrap_assignments(false, cmap_assigns)[2] |> toexpr
578+
579+
fn = build_function_wrapper(
580+
sys, nothing, :out, DestructuredArgs(DestructuredArgs.(solsyms)),
581+
DestructuredArgs.(rps)...; p_start = 3, p_end = length(rps) + 2,
582+
expression = Val{true}, add_observed = false,
583+
extra_assignments = [obs_assigns; body])
591584
return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
592585
end
593586

@@ -601,21 +594,15 @@ function SCCNonlinearFunction{iip}(
601594

602595
obs_assignments = [eq.lhs eq.rhs for eq in _obs]
603596

604-
cmap, cs = get_cmap(sys)
605-
cmap_assignments = [eq.lhs eq.rhs for eq in cmap]
606597
rhss = [eq.rhs - eq.lhs for eq in _eqs]
607-
wrap_code = wrap_assignments(false, cmap_assignments) .∘
608-
(wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘
609-
wrap_parameter_dependencies(sys, false) .∘
610-
wrap_assignments(false, obs_assignments)
611-
f_gen = build_function(
612-
rhss, _dvs, rps..., cachesyms...; wrap_code, expression = Val{true})
598+
f_gen = build_function_wrapper(sys,
599+
rhss, _dvs, rps..., cachesyms...; p_start = 2,
600+
p_end = length(rps) + length(cachesyms) + 1, add_observed = false,
601+
extra_assignments = obs_assignments, expression = Val{true})
613602
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
614603

615604
f(u, p) = f_oop(u, p)
616-
f(u, p::MTKParameters) = f_oop(u, p...)
617605
f(resid, u, p) = f_iip(resid, u, p)
618-
f(resid, u, p::MTKParameters) = f_iip(resid, u, p...)
619606

620607
subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs,
621608
parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))

0 commit comments

Comments
 (0)