Skip to content

Commit 297e724

Browse files
mtfishmanlkdvos
andauthored
Make SVD more general to accommodate graded arrays (#114)
Co-authored-by: Lukas Devos <[email protected]>
1 parent 6c05eb9 commit 297e724

File tree

4 files changed

+76
-34
lines changed

4 files changed

+76
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.3"
4+
version = "0.5.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,32 @@ using BlockArrays:
99
BlockSlice,
1010
BlockVector,
1111
block,
12+
blockedrange,
1213
blockindex,
14+
blocklengths,
1315
findblock,
1416
findblockindex,
1517
mortar
1618

19+
# Get the axes of each block of a block array.
20+
function eachblockaxes(a::AbstractArray)
21+
return map(axes, blocks(a))
22+
end
23+
24+
axis(a::AbstractVector) = axes(a, 1)
25+
26+
# Get the axis of each block of a blocked unit
27+
# range.
28+
function eachblockaxis(a::AbstractVector)
29+
return map(axis, blocks(a))
30+
end
31+
32+
# Take a collection of axes and mortar them
33+
# into a single blocked axis.
34+
function mortar_axis(axs)
35+
return blockedrange(length.(axs))
36+
end
37+
1738
# Custom `BlockedUnitRange` constructor that takes a unit range
1839
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
1940
function blockedunitrange(a::AbstractUnitRange, blocklengths)

src/factorizations/svd.jl

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,19 @@ function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kw
2121
return BlockPermutedDiagonalAlgorithm(alg)
2222
end
2323

24-
# TODO: this should be replaced with a more general similar function that can handle setting
25-
# the blocktype and element type - something like S = similar(A, BlockType(...))
26-
function _similar_S(A::AbstractBlockSparseMatrix, s_axis)
24+
function similar_output(
25+
::typeof(svd_compact!),
26+
A,
27+
s_axis::AbstractUnitRange,
28+
alg::MatrixAlgebraKit.AbstractAlgorithm,
29+
)
30+
U = similar(A, axes(A, 1), s_axis)
2731
T = real(eltype(A))
28-
return BlockSparseArray{T,2,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
32+
# TODO: this should be replaced with a more general similar function that can handle setting
33+
# the blocktype and element type - something like S = similar(A, BlockType(...))
34+
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
35+
Vt = similar(A, s_axis, axes(A, 2))
36+
return U, S, Vt
2937
end
3038

3139
function MatrixAlgebraKit.initialize_output(
@@ -34,33 +42,29 @@ function MatrixAlgebraKit.initialize_output(
3442
bm, bn = blocksize(A)
3543
bmn = min(bm, bn)
3644

37-
brows = blocklengths(axes(A, 1))
38-
bcols = blocklengths(axes(A, 2))
39-
slengths = Vector{Int}(undef, bmn)
45+
brows = eachblockaxis(axes(A, 1))
46+
bcols = eachblockaxis(axes(A, 2))
47+
s_axes = similar(brows, bmn)
4048

4149
# fill in values for blocks that are present
4250
bIs = collect(eachblockstoredindex(A))
4351
browIs = Int.(first.(Tuple.(bIs)))
4452
bcolIs = Int.(last.(Tuple.(bIs)))
4553
for bI in eachblockstoredindex(A)
4654
row, col = Int.(Tuple(bI))
47-
nrows = brows[row]
48-
ncols = bcols[col]
49-
slengths[col] = min(nrows, ncols)
55+
s_axes[col] = argmin(length, (brows[row], bcols[col]))
5056
end
5157

5258
# fill in values for blocks that aren't present, pairing them in order of occurence
5359
# this is a convention, which at least gives the expected results for blockdiagonal
5460
emptyrows = setdiff(1:bm, browIs)
5561
emptycols = setdiff(1:bn, bcolIs)
5662
for (row, col) in zip(emptyrows, emptycols)
57-
slengths[col] = min(brows[row], bcols[col])
63+
s_axes[col] = argmin(length, (brows[row], bcols[col]))
5864
end
5965

60-
s_axis = blockedrange(slengths)
61-
U = similar(A, axes(A, 1), s_axis)
62-
S = _similar_S(A, s_axis)
63-
Vt = similar(A, s_axis, axes(A, 2))
66+
s_axis = mortar_axis(s_axes)
67+
U, S, Vt = similar_output(svd_compact!, A, s_axis, alg)
6468

6569
# allocate output
6670
for bI in eachblockstoredindex(A)
@@ -79,40 +83,46 @@ function MatrixAlgebraKit.initialize_output(
7983
return U, S, Vt
8084
end
8185

86+
function similar_output(
87+
::typeof(svd_full!), A, s_axis::AbstractUnitRange, alg::MatrixAlgebraKit.AbstractAlgorithm
88+
)
89+
U = similar(A, axes(A, 1), s_axis)
90+
T = real(eltype(A))
91+
S = similar(A, T, (s_axis, axes(A, 2)))
92+
Vt = similar(A, axes(A, 2), axes(A, 2))
93+
return U, S, Vt
94+
end
95+
8296
function MatrixAlgebraKit.initialize_output(
8397
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
8498
)
8599
bm, bn = blocksize(A)
86100

87-
brows = blocklengths(axes(A, 1))
88-
slengths = copy(brows)
101+
brows = eachblockaxis(axes(A, 1))
102+
s_axes = similar(brows)
89103

90104
# fill in values for blocks that are present
91105
bIs = collect(eachblockstoredindex(A))
92106
browIs = Int.(first.(Tuple.(bIs)))
93107
bcolIs = Int.(last.(Tuple.(bIs)))
94108
for bI in eachblockstoredindex(A)
95109
row, col = Int.(Tuple(bI))
96-
nrows = brows[row]
97-
slengths[col] = nrows
110+
s_axes[col] = brows[row]
98111
end
99112

100113
# fill in values for blocks that aren't present, pairing them in order of occurence
101114
# this is a convention, which at least gives the expected results for blockdiagonal
102115
emptyrows = setdiff(1:bm, browIs)
103116
emptycols = setdiff(1:bn, bcolIs)
104117
for (row, col) in zip(emptyrows, emptycols)
105-
slengths[col] = brows[row]
118+
s_axes[col] = brows[row]
106119
end
107120
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
108-
slengths[bn + i] = brows[emptyrows[k]]
121+
s_axes[bn + i] = brows[emptyrows[k]]
109122
end
110123

111-
s_axis = blockedrange(slengths)
112-
U = similar(A, axes(A, 1), s_axis)
113-
Tr = real(eltype(A))
114-
S = similar(A, Tr, (s_axis, axes(A, 2)))
115-
Vt = similar(A, axes(A, 2), axes(A, 2))
124+
s_axis = mortar_axis(s_axes)
125+
U, S, Vt = similar_output(svd_full!, A, s_axis, alg)
116126

117127
# allocate output
118128
for bI in eachblockstoredindex(A)

src/factorizations/truncation.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ function MatrixAlgebraKit.findtruncated(
4545
return indexmask
4646
end
4747

48+
function similar_truncate(
49+
::typeof(svd_trunc!),
50+
(U, S, Vᴴ)::TBlockUSVᴴ,
51+
strategy::BlockPermutedDiagonalTruncationStrategy,
52+
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
53+
)
54+
ax = axes(S, 1)
55+
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
56+
s_lengths = filter!(>(0), map(counter, blocks(ax)))
57+
s_axis = blockedrange(s_lengths)
58+
= similar(U, axes(U, 1), s_axis)
59+
= similar(S, s_axis, s_axis)
60+
Ṽᴴ = similar(Vᴴ, s_axis, axes(Vᴴ, 2))
61+
return Ũ, S̃, Ṽᴴ
62+
end
63+
4864
function MatrixAlgebraKit.truncate!(
4965
::typeof(svd_trunc!),
5066
(U, S, Vᴴ)::TBlockUSVᴴ,
@@ -54,13 +70,7 @@ function MatrixAlgebraKit.truncate!(
5470

5571
# first determine the block structure of the output to avoid having assumptions on the
5672
# data structures
57-
ax = axes(S, 1)
58-
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
59-
Slengths = filter!(>(0), map(counter, blocks(ax)))
60-
Sax = blockedrange(Slengths)
61-
= similar(U, axes(U, 1), Sax)
62-
= similar(S, Sax, Sax)
63-
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))
73+
Ũ, S̃, Ṽᴴ = similar_truncate(svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
6474

6575
# then loop over the blocks and assign the data
6676
# TODO: figure out if we can presort and loop over the blocks -
@@ -70,6 +80,7 @@ function MatrixAlgebraKit.truncate!(
7080
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))
7181

7282
I′ = 0 # number of skipped blocks that got fully truncated
83+
ax = axes(S, 1)
7384
for I in 1:blocksize(ax, 1)
7485
b = ax[Block(I)]
7586
mask = indexmask[b]

0 commit comments

Comments
 (0)