Skip to content
Draft
28 changes: 23 additions & 5 deletions src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,30 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
end

function get_root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) !== indp
return get_root_indp(sc)
function get_root_indp(prob::AbstractSciMLProblem)
get_root_indp(prob.f)
end

function get_root_indp(f::T) where {T <: AbstractSciMLFunction}
if hasfield(T, :sys)
return f.sys
elseif hasfield(T, :f) && f.f isa AbstractSciMLFunction
return get_root_indp(f.f)
else
return nothing
end
return indp
end

function get_root_indp(prob::LinearProblem)
get_root_indp(prob.f)
end

get_root_indp(prob::AbstractJumpProblem) = get_root_indp(prob.prob)

get_root_indp(x) = x

function get_root_indp(f::SymbolicLinearInterface)
get_root_indp(f.sys)
end

# Everything from this point on is public API
Expand Down
26 changes: 26 additions & 0 deletions test/JET.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using NonlinearSolve
using LinearSolve
using LinearAlgebra
using ADTypes
using JET
const LS = LinearSolve

function f(u, p)
L, U = cholesky(p.Σ)
rhs = (u .* u .- p.λ)
linprob = LinearProblem(Matrix(L), rhs)
alg = LS.GenericLUFactorization()
sol = LinearSolve.solve(linprob, alg)
return sol.u
end

function minimize(λ=1.0)
ps = (; λ, Σ=hermitianpart(rand(2,2) + 2*I))
u₀ = rand(2)
prob = NonlinearLeastSquaresProblem{false}(f, u₀, ps)
autodiff = AutoForwardDiff(; chunksize=1)
sol = solve(prob, SimpleTrustRegion(; autodiff))
return sol.u
end

@test_opt minimize()
4 changes: 4 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ push!(syss, nsys)
push!(probs, NonlinearProblem(nsys, [u0; p], jac = true))

rate₁ = β * x * y
affect₁ = [x ~ x - σ, y ~ y + σ]
affect₁ = [x ~ Pre(x) - σ, y ~ Pre(y) + σ]
rate₂ = ρ * y
affect₂ = [y ~ y - 1, z ~ z + 1]
affect₂ = [y ~ Pre(y) - 1, z ~ Pre(z) + 1]
j₁ = ConstantRateJump(rate₁, affect₁)
j₂ = ConstantRateJump(rate₂, affect₂)
j₃ = MassActionJump(2 * β + ρ, [z => 1], [x => 1, z => -1])
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ end
@time @safetestset "Aqua" begin
include("aqua.jl")
end
activate_downstream_env()
@time @safetestset "JET" begin
include("JET.jl")
end
end
if GROUP == "Core" || GROUP == "All"
@time @safetestset "Display" begin
Expand Down
Loading