Skip to content

Commit e189861

Browse files
Merge pull request #3419 from AayushSabharwal/as/throw-no-derivative
feat: throw error when differentiating registered function with no derivative in `structural_simplify`
2 parents fe19d26 + 8c43ac5 commit e189861

File tree

4 files changed

+8
-10
lines changed

4 files changed

+8
-10
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ StochasticDelayDiffEq = "1.8.1"
151151
StochasticDiffEq = "6.72.1"
152152
SymbolicIndexingInterface = "0.3.37"
153153
SymbolicUtils = "3.14"
154-
Symbolics = "6.29.1"
154+
Symbolics = "6.29.2"
155155
URIs = "1"
156156
UnPack = "0.1, 1.0"
157157
Unitful = "1.1"

src/structural_transformation/StructuralTransformations.jl

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Setfield: @set!, @set
44
using UnPack: @unpack
55

66
using Symbolics: unwrap, linear_expansion, fast_substitute
7+
import Symbolics
78
using SymbolicUtils
89
using SymbolicUtils.Code
910
using SymbolicUtils.Rewriters

src/structural_transformation/symbolics_tearing.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)
6565

6666
sys = ts.sys
6767
eq = equations(ts)[ieq]
68-
eq = 0 ~ ModelingToolkit.derivative(eq.rhs - eq.lhs, get_iv(sys))
68+
eq = 0 ~ Symbolics.derivative(eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true)
6969
push!(equations(ts), eq)
7070
# Analyze the new equation and update the graph/solvable_graph
7171
# First, copy the previous incidence and add the derivative terms.

test/split_parameters.jl

+5-8
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ end
5151

5252
get_value(interp::Interpolator, t) = interp(t)
5353
@register_symbolic get_value(interp::Interpolator, t)
54-
# get_value(data, t, dt) = data[round(Int, t / dt + 1)]
55-
# @register_symbolic get_value(data::Vector, t, dt)
54+
55+
Symbolics.derivative(::typeof(get_value), args::NTuple{2, Any}, ::Val{2}) = 0
5656

5757
function Sampled(; name, interp = Interpolator(Float64[], 0.0))
5858
pars = @parameters begin
@@ -68,11 +68,10 @@ function Sampled(; name, interp = Interpolator(Float64[], 0.0))
6868
output.u ~ get_value(interpolator, t)
6969
]
7070

71-
return ODESystem(eqs, t, vars, [interpolator]; name, systems,
72-
defaults = [output.u => interp.data[1]])
71+
return ODESystem(eqs, t, vars, [interpolator]; name, systems)
7372
end
7473

75-
vars = @variables y(t)=1 dy(t)=0 ddy(t)=0
74+
vars = @variables y(t) dy(t) ddy(t)
7675
@named src = Sampled(; interp = Interpolator(x, dt))
7776
@named int = Integrator()
7877

@@ -84,11 +83,9 @@ eqs = [y ~ src.output.u
8483
@named sys = ODESystem(eqs, t, vars, []; systems = [int, src])
8584
s = complete(sys)
8685
sys = structural_simplify(sys)
87-
@test_broken ODEProblem(
88-
sys, [], (0.0, t_end), [s.src.interpolator => Interpolator(x, dt)]; tofloat = false)
8986
prob = ODEProblem(
9087
sys, [], (0.0, t_end), [s.src.interpolator => Interpolator(x, dt)];
91-
tofloat = false, build_initializeprob = false)
88+
tofloat = false)
9289
sol = solve(prob, ImplicitEuler());
9390
@test sol.retcode == ReturnCode.Success
9491
@test sol[y][end] == x[end]

0 commit comments

Comments
 (0)