Skip to content

Commit f5ea344

Browse files
Merge pull request #3373 from SciML/dw/odefunctionexpr_specialize
Support specialization in `ODEFunctionExpr`
2 parents 387df59 + 58dde09 commit f5ea344

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
542542
The arguments `dvs` and `ps` are used to set the order of the dependent
543543
variable and parameter vectors, respectively.
544544
"""
545-
struct ODEFunctionExpr{iip} end
545+
struct ODEFunctionExpr{iip, specialize} end
546546

547547
struct ODEFunctionClosure{O, I} <: Function
548548
f_oop::O
@@ -551,7 +551,7 @@ end
551551
(f::ODEFunctionClosure)(u, p, t) = f.f_oop(u, p, t)
552552
(f::ODEFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t)
553553

554-
function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
554+
function ODEFunctionExpr{iip, specialize}(sys::AbstractODESystem, dvs = unknowns(sys),
555555
ps = parameters(sys), u0 = nothing;
556556
version = nothing, tgrad = false,
557557
jac = false, p = nothing,
@@ -560,14 +560,12 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
560560
steady_state = false,
561561
sparsity = false,
562562
observedfun_exp = nothing,
563-
kwargs...) where {iip}
563+
kwargs...) where {iip, specialize}
564564
if !iscomplete(sys)
565565
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`")
566566
end
567567
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
568568

569-
dict = Dict()
570-
571569
fsym = gensym(:f)
572570
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))
573571
tgradsym = gensym(:tgrad)
@@ -590,30 +588,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
590588
_jac = :($jacsym = nothing)
591589
end
592590

591+
Msym = gensym(:M)
593592
M = calculate_massmatrix(sys)
594-
595-
_M = if sparse && !(u0 === nothing || M === I)
596-
SparseArrays.sparse(M)
593+
if sparse && !(u0 === nothing || M === I)
594+
_M = :($Msym = $(SparseArrays.sparse(M)))
597595
elseif u0 === nothing || M === I
598-
M
596+
_M = :($Msym = $M)
599597
else
600-
ArrayInterface.restructure(u0 .* u0', M)
598+
_M = :($Msym = $(ArrayInterface.restructure(u0 .* u0', M)))
601599
end
602600

603601
jp_expr = sparse ? :($similar($(get_jac(sys)[]), Float64)) : :nothing
604602
ex = quote
605-
$_f
606-
$_tgrad
607-
$_jac
608-
M = $_M
609-
ODEFunction{$iip}($fsym,
610-
sys = $sys,
611-
jac = $jacsym,
612-
tgrad = $tgradsym,
613-
mass_matrix = M,
614-
jac_prototype = $jp_expr,
615-
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
616-
observed = $observedfun_exp)
603+
let $_f, $_tgrad, $_jac, $_M
604+
ODEFunction{$iip, $specialize}($fsym,
605+
sys = $sys,
606+
jac = $jacsym,
607+
tgrad = $tgradsym,
608+
mass_matrix = $Msym,
609+
jac_prototype = $jp_expr,
610+
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
611+
observed = $observedfun_exp)
612+
end
617613
end
618614
!linenumbers ? Base.remove_linenums!(ex) : ex
619615
end
@@ -622,6 +618,14 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
622618
ODEFunctionExpr{true}(sys, args...; kwargs...)
623619
end
624620

621+
function ODEFunctionExpr{true}(sys::AbstractODESystem, args...; kwargs...)
622+
return ODEFunctionExpr{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
623+
end
624+
625+
function ODEFunctionExpr{false}(sys::AbstractODESystem, args...; kwargs...)
626+
return ODEFunctionExpr{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
627+
end
628+
625629
"""
626630
```julia
627631
DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),

test/odesystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,25 @@ f.f(du, u, p, 0.1)
9797
@test du == [4, 0, -16]
9898
@test_throws ArgumentError f.f(u, p, 0.1)
9999

100+
#check iip
101+
f = eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β]))
102+
f2 = ODEFunction(de, [x, y, z], [σ, ρ, β])
103+
@test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2)
104+
@test SciMLBase.specialization(f) === SciMLBase.specialization(f2)
105+
for iip in (true, false)
106+
f = eval(ODEFunctionExpr{iip}(de, [x, y, z], [σ, ρ, β]))
107+
f2 = ODEFunction{iip}(de, [x, y, z], [σ, ρ, β])
108+
@test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2) === iip
109+
@test SciMLBase.specialization(f) === SciMLBase.specialization(f2)
110+
111+
for specialize in (SciMLBase.AutoSpecialize, SciMLBase.FullSpecialize)
112+
f = eval(ODEFunctionExpr{iip, specialize}(de, [x, y, z], [σ, ρ, β]))
113+
f2 = ODEFunction{iip, specialize}(de, [x, y, z], [σ, ρ, β])
114+
@test SciMLBase.isinplace(f) === SciMLBase.isinplace(f2) === iip
115+
@test SciMLBase.specialization(f) === SciMLBase.specialization(f2) === specialize
116+
end
117+
end
118+
100119
#check sparsity
101120
f = eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β], sparsity = true))
102121
@test f.sparsity == ModelingToolkit.jacobian_sparsity(de)

0 commit comments

Comments
 (0)