@@ -542,7 +542,7 @@ Create a Julia expression for an `ODEFunction` from the [`ODESystem`](@ref).
542
542
The arguments `dvs` and `ps` are used to set the order of the dependent
543
543
variable and parameter vectors, respectively.
544
544
"""
545
- struct ODEFunctionExpr{iip} end
545
+ struct ODEFunctionExpr{iip, specialize } end
546
546
547
547
struct ODEFunctionClosure{O, I} <: Function
548
548
f_oop:: O
551
551
(f:: ODEFunctionClosure )(u, p, t) = f. f_oop (u, p, t)
552
552
(f:: ODEFunctionClosure )(du, u, p, t) = f. f_iip (du, u, p, t)
553
553
554
- function ODEFunctionExpr {iip} (sys:: AbstractODESystem , dvs = unknowns (sys),
554
+ function ODEFunctionExpr {iip, specialize } (sys:: AbstractODESystem , dvs = unknowns (sys),
555
555
ps = parameters (sys), u0 = nothing ;
556
556
version = nothing , tgrad = false ,
557
557
jac = false , p = nothing ,
@@ -560,14 +560,12 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
560
560
steady_state = false ,
561
561
sparsity = false ,
562
562
observedfun_exp = nothing ,
563
- kwargs... ) where {iip}
563
+ kwargs... ) where {iip, specialize }
564
564
if ! iscomplete (sys)
565
565
error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunctionExpr`" )
566
566
end
567
567
f_oop, f_iip = generate_function (sys, dvs, ps; expression = Val{true }, kwargs... )
568
568
569
- dict = Dict ()
570
-
571
569
fsym = gensym (:f )
572
570
_f = :($ fsym = $ ODEFunctionClosure ($ f_oop, $ f_iip))
573
571
tgradsym = gensym (:tgrad )
@@ -590,30 +588,28 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
590
588
_jac = :($ jacsym = nothing )
591
589
end
592
590
591
+ Msym = gensym (:M )
593
592
M = calculate_massmatrix (sys)
594
-
595
- _M = if sparse && ! (u0 === nothing || M === I)
596
- SparseArrays. sparse (M)
593
+ if sparse && ! (u0 === nothing || M === I)
594
+ _M = :($ Msym = $ (SparseArrays. sparse (M)))
597
595
elseif u0 === nothing || M === I
598
- M
596
+ _M = :( $ Msym = $ M)
599
597
else
600
- ArrayInterface. restructure (u0 .* u0' , M)
598
+ _M = :( $ Msym = $ ( ArrayInterface. restructure (u0 .* u0' , M)) )
601
599
end
602
600
603
601
jp_expr = sparse ? :($ similar ($ (get_jac (sys)[]), Float64)) : :nothing
604
602
ex = quote
605
- $ _f
606
- $ _tgrad
607
- $ _jac
608
- M = $ _M
609
- ODEFunction {$iip} ($ fsym,
610
- sys = $ sys,
611
- jac = $ jacsym,
612
- tgrad = $ tgradsym,
613
- mass_matrix = M,
614
- jac_prototype = $ jp_expr,
615
- sparsity = $ (sparsity ? jacobian_sparsity (sys) : nothing ),
616
- observed = $ observedfun_exp)
603
+ let $ _f, $ _tgrad, $ _jac, $ _M
604
+ ODEFunction {$iip, $specialize} ($ fsym,
605
+ sys = $ sys,
606
+ jac = $ jacsym,
607
+ tgrad = $ tgradsym,
608
+ mass_matrix = $ Msym,
609
+ jac_prototype = $ jp_expr,
610
+ sparsity = $ (sparsity ? jacobian_sparsity (sys) : nothing ),
611
+ observed = $ observedfun_exp)
612
+ end
617
613
end
618
614
! linenumbers ? Base. remove_linenums! (ex) : ex
619
615
end
@@ -622,6 +618,14 @@ function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
622
618
ODEFunctionExpr {true} (sys, args... ; kwargs... )
623
619
end
624
620
621
+ function ODEFunctionExpr {true} (sys:: AbstractODESystem , args... ; kwargs... )
622
+ return ODEFunctionExpr {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
623
+ end
624
+
625
+ function ODEFunctionExpr {false} (sys:: AbstractODESystem , args... ; kwargs... )
626
+ return ODEFunctionExpr {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
627
+ end
628
+
625
629
"""
626
630
```julia
627
631
DAEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
0 commit comments