|
| 1 | +using ModelingToolkit |
| 2 | +using ModelingToolkit: t_nounits as t, D_nounits as D |
| 3 | + |
| 4 | +@testset "`generate_custom_function`" begin |
| 5 | + @variables x(t) y(t)[1:3] |
| 6 | + @parameters p1=1.0 p2[1:3]=[1.0, 2.0, 3.0] p3::Int=1 p4::Bool=false |
| 7 | + |
| 8 | + sys = complete(ODESystem(Equation[], t, [x; y], [p1, p2, p3, p4]; name = :sys)) |
| 9 | + u0 = [1.0, 2.0, 3.0, 4.0] |
| 10 | + p = ModelingToolkit.MTKParameters(sys, []) |
| 11 | + |
| 12 | + fn1 = generate_custom_function(sys, x + y[1] + p1 + p2[1] + p3 * t; expression = Val(false)) |
| 13 | + @test fn1(u0, p, 0.0) == 5.0 |
| 14 | + |
| 15 | + fn2 = generate_custom_function( |
| 16 | + sys, x + y[1] + p1 + p2[1] + p3 * t, [x], [p1, p2, p3]; expression = Val(false)) |
| 17 | + @test fn1(u0, p, 0.0) == 5.0 |
| 18 | + |
| 19 | + fn3_oop, fn3_iip = generate_custom_function( |
| 20 | + sys, [x + y[2], y[3] + p2[2], p1 + p3, 3t]; expression = Val(false)) |
| 21 | + |
| 22 | + buffer = zeros(4) |
| 23 | + fn3_iip(buffer, u0, p, 1.0) |
| 24 | + @test buffer == [4.0, 6.0, 2.0, 3.0] |
| 25 | + @test fn3_oop(u0, p, 1.0) == [4.0, 6.0, 2.0, 3.0] |
| 26 | + |
| 27 | + fn4 = generate_custom_function(sys, ifelse(p4, p1, p2[2]); expression = Val(false)) |
| 28 | + @test fn4(u0, p, 1.0) == 2.0 |
| 29 | + fn5 = generate_custom_function(sys, ifelse(!p4, p1, p2[2]); expression = Val(false)) |
| 30 | + @test fn5(u0, p, 1.0) == 1.0 |
| 31 | + |
| 32 | + @variables x y[1:3] |
| 33 | + sys = complete(NonlinearSystem(Equation[], [x; y], [p1, p2, p3, p4]; name = :sys)) |
| 34 | + p = MTKParameters(sys, []) |
| 35 | + |
| 36 | + fn1 = generate_custom_function(sys, x + y[1] + p1 + p2[1] + p3; expression = Val(false)) |
| 37 | + @test fn1(u0, p) == 6.0 |
| 38 | + |
| 39 | + fn2 = generate_custom_function( |
| 40 | + sys, x + y[1] + p1 + p2[1] + p3, [x], [p1, p2, p3]; expression = Val(false)) |
| 41 | + @test fn1(u0, p) == 6.0 |
| 42 | + |
| 43 | + fn3_oop, fn3_iip = generate_custom_function( |
| 44 | + sys, [x + y[2], y[3] + p2[2], p1 + p3]; expression = Val(false)) |
| 45 | + |
| 46 | + buffer = zeros(3) |
| 47 | + fn3_iip(buffer, u0, p) |
| 48 | + @test buffer == [4.0, 6.0, 2.0] |
| 49 | + @test fn3_oop(u0, p, 1.0) == [4.0, 6.0, 2.0] |
| 50 | + |
| 51 | + fn4 = generate_custom_function(sys, ifelse(p4, p1, p2[2]); expression = Val(false)) |
| 52 | + @test fn4(u0, p, 1.0) == 2.0 |
| 53 | + fn5 = generate_custom_function(sys, ifelse(!p4, p1, p2[2]); expression = Val(false)) |
| 54 | + @test fn5(u0, p, 1.0) == 1.0 |
| 55 | +end |
0 commit comments