Skip to content

Commit 2cfb62a

Browse files
Make length of Partials known at compile time in ForwardDiff overloads (#727)
* make sure that Partials length is known * add Dual problem JET tests * Update test/nopre/jet.jl * use dual_prob --------- Co-authored-by: Christopher Rackauckas <[email protected]>
1 parent 2f41c8e commit 2cfb62a

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ const DualBLinearProblem = LinearProblem{
3434
const DualAbstractLinearProblem = Union{
3535
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3636

37-
LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual}
37+
LinearSolve.@concrete mutable struct DualLinearCache{DT}
3838
linear_cache
3939

4040
partials_A
@@ -113,10 +113,10 @@ function linearsolve_dual_solution(
113113
end
114114

115115
function linearsolve_dual_solution(u::AbstractArray, partials,
116-
cache::DualLinearCache{DT}) where {DT}
116+
cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T,V,N}}
117117
# Handle single-level duals for arrays
118118
partials_list = RecursiveArrayTools.VectorOfArray(partials)
119-
return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))),
119+
return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials{N,V}(NTuple{N,V}(pᵢ))),
120120
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
121121
end
122122

test/nopre/jet.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearSolve, RecursiveFactorization, LinearAlgebra, SparseArrays, Test
1+
using LinearSolve, ForwardDiff, RecursiveFactorization, LinearAlgebra, SparseArrays, Test
22
using JET
33

44
# Dense problem setup
@@ -22,6 +22,18 @@ prob_sparse = LinearProblem(A_sparse, b)
2222
A_sparse_spd = sparse(A_spd)
2323
prob_sparse_spd = LinearProblem(A_sparse_spd, b)
2424

25+
# Dual problem set up
26+
function h(p)
27+
(A = [p[1] p[2]+1 p[2]^3;
28+
3*p[1] p[1]+5 p[2] * p[1]-4;
29+
p[2]^2 9*p[1] p[2]],
30+
b = [p[1] + 1, p[2] * 2, p[1]^2])
31+
end
32+
33+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
34+
35+
dual_prob = LinearProblem(A, b)
36+
2537
@testset "JET Tests for Dense Factorizations" begin
2638
# Working tests - these pass JET optimization checks
2739
JET.@test_opt init(prob, nothing)
@@ -109,3 +121,11 @@ end
109121
JET.@test_opt solve(prob) broken=true
110122
JET.@test_opt solve(prob_sparse) broken=true
111123
end
124+
125+
@testset "JET Tests for creating Dual solutions" begin
126+
# Make sure there's no runtime dispatch when making solutions of Dual problems
127+
dual_cache = init(dual_prob)
128+
ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt)
129+
JET.@test_opt ext.linearsolve_dual_solution(
130+
[1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dual_cache)
131+
end

0 commit comments

Comments
 (0)