Skip to content

Commit a4eb3f3

Browse files
authored
Add rank and cond (#220)
1 parent 511fa90 commit a4eb3f3

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
219219
Sp = view(S, 1:p)
220220

221221
# rank
222-
r = findlast(>=(tol), S)
222+
r = count(>(tol), S)
223223

224224
# compute antihermitian part of projection of ΔU and ΔV onto U and V
225225
# also already subtract this projection from ΔU and ΔV

src/TensorKit.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby!
7272
export leftorth, rightorth, leftnull, rightnull,
7373
leftorth!, rightorth!, leftnull!, rightnull!,
7474
tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!,
75-
isposdef, isposdef!, ishermitian, sylvester
75+
isposdef, isposdef!, ishermitian, sylvester, rank, cond
7676
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
7777
repartition!
7878
export catdomain, catcodomain
@@ -119,7 +119,7 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr,
119119
adjoint, adjoint!, transpose, transpose!,
120120
lu, pinv, sylvester,
121121
eigen, eigen!, svd, svd!,
122-
isposdef, isposdef!, ishermitian,
122+
isposdef, isposdef!, ishermitian, rank, cond,
123123
Diagonal, Hermitian
124124

125125
using SparseArrays: SparseMatrixCSC, sparse, nzrange, rowvals, nonzeros

src/tensors/linalg.jl

+26
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,32 @@ function _norm(blockiter, p::Real, init::Real)
271271
end
272272
end
273273

274+
_default_rtol(t) = eps(real(float(scalartype(t)))) * min(dim(domain(t)), dim(codomain(t)))
275+
276+
function LinearAlgebra.rank(t::AbstractTensorMap; atol::Real=0,
277+
rtol::Real=atol > 0 ? 0 : _default_rtol(t))
278+
dim(t) == 0 && return 0
279+
S = LinearAlgebra.svdvals(t)
280+
tol = max(atol, rtol * maximum(first, values(S)))
281+
return sum(cs -> dim(cs[1]) * count(>(tol), cs[2]), S)
282+
end
283+
284+
function LinearAlgebra.cond(t::AbstractTensorMap, p::Real=2)
285+
if p == 2
286+
if dim(t) == 0
287+
domain(t) == codomain(t) ||
288+
throw(SpaceMismatch("`cond` requires domain and codomain to be the same"))
289+
return zero(real(float(scalartype(t))))
290+
end
291+
S = LinearAlgebra.svdvals(t)
292+
maxS = maximum(first, values(S))
293+
minS = minimum(last, values(S))
294+
return iszero(maxS) ? oftype(maxS, Inf) : (maxS / minS)
295+
else
296+
throw(ArgumentError("cond currently only defined for p=2"))
297+
end
298+
end
299+
274300
# TensorMap trace
275301
function LinearAlgebra.tr(t::AbstractTensorMap)
276302
domain(t) == codomain(t) ||

test/tensors.jl

+25
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,23 @@ for V in spacelist
546546
@test b s′[c]
547547
end
548548
end
549+
@testset "cond and rank" begin
550+
t2 = permute(t, ((3, 4, 2), (1, 5)))
551+
d1 = dim(codomain(t2))
552+
d2 = dim(domain(t2))
553+
@test rank(t2) == min(d1, d2)
554+
M = leftnull(t2)
555+
@test rank(M) == max(d1, d2) - min(d1, d2)
556+
t3 = unitary(T, V1 V2, V1 V2)
557+
@test cond(t3) one(real(T))
558+
@test rank(t3) == dim(V1 V2)
559+
t4 = randn(T, V1 V2, V1 V2)
560+
t4 = (t4 + t4') / 2
561+
vals = LinearAlgebra.eigvals(t4)
562+
λmax = maximum(s -> maximum(abs, s), values(vals))
563+
λmin = minimum(s -> minimum(abs, s), values(vals))
564+
@test cond(t4) λmax / λmin
565+
end
549566
end
550567
@testset "empty tensor" begin
551568
t = randn(T, V1 V2, zero(V1))
@@ -586,6 +603,13 @@ for V in spacelist
586603
@test U == t
587604
@test dim(U) == dim(S) == dim(V)
588605
end
606+
@testset "cond and rank" begin
607+
@test rank(t) == 0
608+
W2 = zero(V1) * zero(V2)
609+
t2 = rand(W2, W2)
610+
@test rank(t2) == 0
611+
@test cond(t2) == 0.0
612+
end
589613
end
590614
t = rand(T, V1 V1' V2 V2')
591615
@testset "eig and isposdef" begin
@@ -615,6 +639,7 @@ for V in spacelist
615639
@test V
616640
λ = minimum(minimum(real(LinearAlgebra.diag(b)))
617641
for (c, b) in blocks(D))
642+
@test cond(Ṽ) one(real(T))
618643
@test isposdef(t2) == isposdef(λ)
619644
@test isposdef(t2 - λ * one(t2) + 0.1 * one(t2))
620645
@test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2))

0 commit comments

Comments
 (0)