Skip to content

Commit 02a8172

Browse files
authored
Add rules for det and logdet of Cholesky (#613)
* Add rules for `det` and `logdet` of `Cholesky` * Fix tests on Julia 1.6 * Revert restricting `det` to `StridedMatrix` * Update Project.toml * Handle Cholesky factorizations of singular matrices * Handle zero co-tangents
1 parent d246d12 commit 02a8172

File tree

3 files changed

+63
-1
lines changed

3 files changed

+63
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.32.1"
3+
version = "1.33.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

+28
Original file line numberDiff line numberDiff line change
@@ -551,3 +551,31 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
551551
end
552552
return getproperty(F, x), getproperty_cholesky_pullback
553553
end
554+
555+
# `det` and `logdet` for `Cholesky`
556+
function rrule(::typeof(det), C::Cholesky)
557+
y = det(C)
558+
diagF = _diag_view(C.factors)
559+
function det_Cholesky_pullback(ȳ)
560+
ΔF = Diagonal(_x_divide_conj_y.(2 ** conj(y), diagF))
561+
ΔC = Tangent{typeof(C)}(; factors=ΔF)
562+
return NoTangent(), ΔC
563+
end
564+
return y, det_Cholesky_pullback
565+
end
566+
567+
function rrule(::typeof(logdet), C::Cholesky)
568+
y = logdet(C)
569+
diagF = _diag_view(C.factors)
570+
function logdet_Cholesky_pullback(ȳ)
571+
ΔC = Tangent{typeof(C)}(; factors=Diagonal(_x_divide_conj_y.(2 * ȳ, diagF)))
572+
return NoTangent(), ΔC
573+
end
574+
return y, logdet_Cholesky_pullback
575+
end
576+
577+
# Return `x / conj(y)`, or a type-stable 0 if `iszero(x)`
578+
function _x_divide_conj_y(x, y)
579+
z = x / conj(y)
580+
return iszero(x) ? zero(z) : z
581+
end

test/rulesets/LinearAlgebra/factorization.jl

+34
Original file line numberDiff line numberDiff line change
@@ -432,5 +432,39 @@ end
432432
ΔX_symmetric = chol_back_sym(Δ)[2]
433433
@test sym_back(ΔX_symmetric)[2] dX_pullback(Δ)[2]
434434
end
435+
436+
@testset "det and logdet (uplo=$p)" for p in (:U, :L)
437+
@testset "$op" for op in (det, logdet)
438+
@testset "$T" for T in (Float64, ComplexF64)
439+
n = 5
440+
# rand (not randn) so det will be postive, so logdet will be defined
441+
A = 3 * rand(T, (n, n))
442+
X = Cholesky(A * A' + I, p, 0)
443+
X̄_acc = Tangent{typeof(X)}(; factors=Diagonal(randn(T, n))) # sensitivity is always a diagonal
444+
test_rrule(op, X X̄_acc)
445+
446+
# return type
447+
_, op_pullback = rrule(op, X)
448+
= op_pullback(2.7)[2]
449+
@testisa Tangent{<:Cholesky}
450+
@test.factors isa Diagonal
451+
452+
# zero co-tangent
453+
= op_pullback(0.0)[2]
454+
@test all(iszero, X̄.factors)
455+
end
456+
end
457+
458+
@testset "singular ($T)" for T in (Float64, ComplexF64)
459+
n = 5
460+
L = LowerTriangular(randn(T, (n, n)))
461+
L[1, 1] = zero(T)
462+
X = cholesky(L * L'; check=false)
463+
detX, det_pullback = rrule(det, X)
464+
ΔX = det_pullback(rand())[2]
465+
@test iszero(detX)
466+
@test ΔX.factors isa Diagonal && all(iszero, ΔX.factors)
467+
end
468+
end
435469
end
436470
end

0 commit comments

Comments
 (0)