Skip to content

Commit 35c8f4c

Browse files
Merge pull request #3360 from AayushSabharwal/as/refactor-codegen
refactor: centralize all code generation
2 parents e8c464f + 1fc081f commit 35c8f4c

24 files changed

+427
-572
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ StochasticDiffEq = "6.72.1"
147147
StochasticDelayDiffEq = "1.8.1"
148148
SymbolicIndexingInterface = "0.3.37"
149149
SymbolicUtils = "3.10.1"
150-
Symbolics = "6.25"
150+
Symbolics = "6.27"
151151
URIs = "1"
152152
UnPack = "0.1, 1.0"
153153
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ include("systems/connectors.jl")
150150
include("systems/analysis_points.jl")
151151
include("systems/imperative_affect.jl")
152152
include("systems/callbacks.jl")
153+
include("systems/codegen_utils.jl")
153154
include("systems/problem_utils.jl")
154155

155156
include("systems/nonlinear/nonlinearsystem.jl")

src/inputoutput.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,17 +249,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
249249
ddvs = map(Differential(get_iv(sys)), dvs)
250250
args = (ddvs, args...)
251251
end
252-
process = get_postprocess_fbody(sys)
253-
wrapped_arrays_vars = disturbance_argument ?
254-
wrap_array_vars(
255-
sys, rhss; dvs, ps, inputs, extra_args = (disturbance_inputs,)) :
256-
wrap_array_vars(sys, rhss; dvs, ps, inputs)
257-
f = build_function(rhss, args...; postprocess_fbody = process,
258-
expression = Val{true}, wrap_code = wrap_mtkparameters(
259-
sys, false, 3, Int(disturbance_argument) + 1) .∘
260-
wrapped_arrays_vars .∘
261-
wrap_parameter_dependencies(sys, false),
262-
kwargs...)
252+
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
253+
p_end = length(p) + 2 + implicit_dae)
263254
f = eval_or_rgf.(f; eval_expression, eval_module)
264255
(; f, dvs, ps, io_sys = sys)
265256
end

src/systems/abstractsystem.jl

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -161,47 +161,27 @@ time-independent systems. If `split=true` (the default) was passed to [`complete
161161
object.
162162
"""
163163
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
164-
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing,
164+
ps = parameters(sys);
165165
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__,
166166
cachesyms::Tuple = (), kwargs...)
167167
if !iscomplete(sys)
168168
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
169169
end
170170
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
171171
isscalar = !(exprs isa AbstractArray)
172-
if wrap_code === nothing
173-
wrap_code = isscalar ? identity : (identity, identity)
174-
end
175-
pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
176-
if postprocess_fbody === nothing
177-
postprocess_fbody = pre
178-
end
179-
if states === nothing
180-
states = sol_states
181-
end
182172
fnexpr = if is_time_dependent(sys)
183-
build_function(exprs,
173+
build_function_wrapper(sys, exprs,
184174
dvs,
185175
p...,
186176
get_iv(sys);
187177
kwargs...,
188-
postprocess_fbody,
189-
states,
190-
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
191-
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
192-
wrap_parameter_dependencies(sys, isscalar),
193178
expression = Val{true}
194179
)
195180
else
196-
build_function(exprs,
181+
build_function_wrapper(sys, exprs,
197182
dvs,
198183
p...;
199184
kwargs...,
200-
postprocess_fbody,
201-
states,
202-
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
203-
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
204-
wrap_parameter_dependencies(sys, isscalar),
205185
expression = Val{true}
206186
)
207187
end
@@ -844,7 +824,7 @@ end
844824

845825
function SymbolicIndexingInterface.all_symbols(sys::AbstractSystem)
846826
syms = all_variable_symbols(sys)
847-
for other in (parameter_symbols(sys), independent_variable_symbols(sys))
827+
for other in (full_parameters(sys), independent_variable_symbols(sys))
848828
isempty(other) || (syms = vcat(syms, other))
849829
end
850830
return syms
@@ -2578,7 +2558,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
25782558

25792559
fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]
25802560
fun = eval_or_rgf(fun_expr; eval_expression, eval_module)
2581-
dx = fun(sts, p..., t)
2561+
dx = fun(sts, p, t)
25822562

25832563
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
25842564
y = h(sts, p, t)

src/systems/callbacks.jl

Lines changed: 17 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -519,70 +519,19 @@ end
519519
# handles ensuring that affect! functions work with integrator arguments
520520
function add_integrator_header(
521521
sys::AbstractSystem, integrator = gensym(:MTKIntegrator), out = :u)
522-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
523-
function (expr)
524-
p = gensym(:p)
525-
Func(
526-
[
527-
DestructuredArgs([expr.args[1], p, expr.args[end]],
528-
integrator, inds = [:u, :p, :t])
529-
],
530-
[],
531-
Let(
532-
[DestructuredArgs([arg.name for arg in expr.args[2:(end - 1)]], p),
533-
expr.args[2:(end - 1)]...],
534-
expr.body,
535-
false)
536-
)
537-
end,
538-
function (expr)
539-
p = gensym(:p)
540-
Func(
541-
[
542-
DestructuredArgs([expr.args[1], expr.args[2], p, expr.args[end]],
543-
integrator, inds = [out, :u, :p, :t])
544-
],
545-
[],
546-
Let(
547-
[DestructuredArgs([arg.name for arg in expr.args[3:(end - 1)]], p),
548-
expr.args[3:(end - 1)]...],
549-
expr.body,
550-
false)
551-
)
552-
end
553-
else
554-
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
555-
expr.body),
556-
expr -> Func(
557-
[DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
558-
expr.body)
559-
end
522+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
523+
expr.body),
524+
expr -> Func(
525+
[DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
526+
expr.body)
560527
end
561528

562529
function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrator))
563-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
564-
function (expr)
565-
p = gensym(:p)
566-
res = Func(
567-
[expr.args[1], expr.args[2],
568-
DestructuredArgs([p], integrator, inds = [:p])],
569-
[],
570-
Let(
571-
[
572-
DestructuredArgs([arg.name for arg in expr.args[3:end]], p),
573-
expr.args[3:end]...
574-
], expr.body, false
575-
)
576-
)
577-
return res
578-
end
579-
else
580-
expr -> Func(
581-
[expr.args[1], expr.args[2],
582-
DestructuredArgs(expr.args[3:end], integrator, inds = [:p])],
583-
[],
584-
expr.body)
585-
end
530+
expr -> Func(
531+
[expr.args[1], expr.args[2],
532+
DestructuredArgs(expr.args[3:end], integrator, inds = [:p])],
533+
[],
534+
expr.body)
586535
end
587536

588537
function callback_save_header(sys::AbstractSystem, cb)
@@ -628,11 +577,10 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
628577
cmap = map(x -> x => getdefault(x), cs)
629578
condit = substitute(condit, cmap)
630579
end
631-
expr = build_function(
580+
expr = build_function_wrapper(sys,
632581
condit, u, t, p...; expression = Val{true},
633-
wrap_code = condition_header(sys) .∘
634-
wrap_array_vars(sys, condit; dvs, ps, inputs = true) .∘
635-
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
582+
p_start = 3, p_end = length(p) + 2,
583+
wrap_code = condition_header(sys),
636584
kwargs...)
637585
if expression == Val{true}
638586
return expr
@@ -715,14 +663,12 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
715663
end
716664
t = get_iv(sys)
717665
integ = gensym(:MTKIntegrator)
718-
pre = get_preprocess_constants(rhss)
719-
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
666+
rf_oop, rf_ip = build_function_wrapper(
667+
sys, rhss, u, p..., t; expression = Val{true},
720668
wrap_code = callback_save_header(sys, cb) .∘
721-
add_integrator_header(sys, integ, outvar) .∘
722-
wrap_array_vars(sys, rhss; dvs, ps = _ps) .∘
723-
wrap_parameter_dependencies(sys, false),
669+
add_integrator_header(sys, integ, outvar),
724670
outputidxs = update_inds,
725-
postprocess_fbody = pre,
671+
create_bindings = false,
726672
kwargs...)
727673
# applied user-provided function to the generated expression
728674
if postprocess_affect_expr! !== nothing

0 commit comments

Comments
 (0)