Skip to content

Commit 9eb742b

Browse files
authored
Add truncation functionality for SVD (#113)
1 parent 6c92d7c commit 9eb742b

File tree

5 files changed

+178
-3
lines changed

5 files changed

+178
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4444

4545
# factorizations
4646
include("factorizations/svd.jl")
47+
include("factorizations/truncation.jl")
4748

4849
end

src/factorizations/svd.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using MatrixAlgebraKit: MatrixAlgebraKit, svd_compact!, svd_full!
44
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
55
66
A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped algorithm on
7-
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted block-diagonal matrix.
7+
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or
8+
a block permuted block-diagonal matrix.
89
"""
910
struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
1011
MatrixAlgebraKit.AbstractAlgorithm

src/factorizations/truncation.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc!
2+
3+
function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
4+
D = BlockSparseVector{T}(undef, axes(A, 1))
5+
for I in eachblockstoredindex(A)
6+
if ==(Int.(Tuple(I))...)
7+
D[Tuple(I)[1]] = diagview(A[I])
8+
end
9+
end
10+
return D
11+
end
12+
13+
"""
14+
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)
15+
16+
A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block
17+
basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted
18+
block-diagonal matrix.
19+
"""
20+
struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
21+
strategy::T
22+
end
23+
24+
const TBlockUSVᴴ = Tuple{
25+
<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix
26+
}
27+
28+
function MatrixAlgebraKit.truncate!(
29+
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy
30+
)
31+
# TODO assert blockdiagonal
32+
return MatrixAlgebraKit.truncate!(
33+
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
34+
)
35+
end
36+
37+
# cannot use regular slicing here: I want to slice without altering blockstructure
38+
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
39+
function MatrixAlgebraKit.findtruncated(
40+
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
41+
)
42+
ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy)
43+
indexmask = falses(length(values))
44+
indexmask[ind] .= true
45+
return indexmask
46+
end
47+
48+
function MatrixAlgebraKit.truncate!(
49+
::typeof(svd_trunc!),
50+
(U, S, Vᴴ)::TBlockUSVᴴ,
51+
strategy::BlockPermutedDiagonalTruncationStrategy,
52+
)
53+
indexmask = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
54+
55+
# first determine the block structure of the output to avoid having assumptions on the
56+
# data structures
57+
ax = axes(S, 1)
58+
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
59+
Slengths = filter!(>(0), map(counter, blocks(ax)))
60+
Sax = blockedrange(Slengths)
61+
= similar(U, axes(U, 1), Sax)
62+
= similar(S, Sax, Sax)
63+
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))
64+
65+
# then loop over the blocks and assign the data
66+
# TODO: figure out if we can presort and loop over the blocks -
67+
# for now this has issues with missing blocks
68+
bI_Us = collect(eachblockstoredindex(U))
69+
bI_Ss = collect(eachblockstoredindex(S))
70+
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))
71+
72+
I′ = 0 # number of skipped blocks that got fully truncated
73+
for I in 1:blocksize(ax, 1)
74+
b = ax[Block(I)]
75+
mask = indexmask[b]
76+
77+
if !any(mask)
78+
I′ += 1
79+
continue
80+
end
81+
82+
bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error(
83+
"No U-block found for $I"
84+
)
85+
bU = Tuple(bI_Us[bU_id])
86+
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask]
87+
88+
bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error(
89+
"No Vᴴ-block found for $I"
90+
)
91+
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id])
92+
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :]
93+
94+
bS_id = findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss)
95+
if !isnothing(bS_id)
96+
bS = Tuple(bI_Ss[bS_id])
97+
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask])
98+
end
99+
end
100+
101+
return Ũ, S̃, Ṽᴴ
102+
end

test/test_factorizations.jl

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
3-
using MatrixAlgebraKit: svd_compact, svd_full
3+
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank, trunctol
44
using LinearAlgebra: LinearAlgebra
55
using Random: Random
66
using Test: @inferred, @testset, @test
@@ -83,3 +83,74 @@ end
8383
usv = svd_full(c)
8484
@test test_svd(c, usv; full=true)
8585
end
86+
87+
# svd_trunc!
88+
# ----------
89+
90+
@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params
91+
a = BlockSparseArray{T}(undef, m, n)
92+
93+
# test blockdiagonal
94+
for i in LinearAlgebra.diagind(blocks(a))
95+
I = CartesianIndices(blocks(a))[i]
96+
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
97+
end
98+
99+
minmn = min(size(a)...)
100+
r = max(1, minmn - 2)
101+
trunc = truncrank(r)
102+
103+
U1, S1, V1ᴴ = svd_trunc(a; trunc)
104+
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
105+
@test size(U1) == size(U2)
106+
@test size(S1) == size(S2)
107+
@test size(V1ᴴ) == size(V2ᴴ)
108+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
109+
110+
@test (U1' * U1 LinearAlgebra.I)
111+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
112+
113+
atol = minimum(LinearAlgebra.diag(S1)) + 10 * eps(real(T))
114+
trunc = trunctol(atol)
115+
116+
U1, S1, V1ᴴ = svd_trunc(a; trunc)
117+
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
118+
@test size(U1) == size(U2)
119+
@test size(S1) == size(S2)
120+
@test size(V1ᴴ) == size(V2ᴴ)
121+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
122+
123+
@test (U1' * U1 LinearAlgebra.I)
124+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
125+
126+
# test permuted blockdiagonal
127+
perm = Random.randperm(length(m))
128+
b = a[Block.(perm), Block.(1:length(n))]
129+
for trunc in (truncrank(r), trunctol(atol))
130+
U1, S1, V1ᴴ = svd_trunc(b; trunc)
131+
U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc)
132+
@test size(U1) == size(U2)
133+
@test size(S1) == size(S2)
134+
@test size(V1ᴴ) == size(V2ᴴ)
135+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
136+
137+
@test (U1' * U1 LinearAlgebra.I)
138+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
139+
end
140+
141+
# test permuted blockdiagonal with missing row/col
142+
I_removed = rand(eachblockstoredindex(b))
143+
c = copy(b)
144+
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
145+
for trunc in (truncrank(r), trunctol(atol))
146+
U1, S1, V1ᴴ = svd_trunc(c; trunc)
147+
U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc)
148+
@test size(U1) == size(U2)
149+
@test size(S1) == size(S2)
150+
@test size(V1ᴴ) == size(V2ᴴ)
151+
@test Matrix(U1 * S1 * V1ᴴ) U2 * S2 * V2ᴴ
152+
153+
@test (U1' * U1 LinearAlgebra.I)
154+
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
155+
end
156+
end

0 commit comments

Comments
 (0)