Skip to content

Commit 2120b7a

Browse files
authored
Make SVD a bit more customizable (#121)
1 parent 841bdbf commit 2120b7a

File tree

2 files changed

+28
-23
lines changed

2 files changed

+28
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/factorizations/svd.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,14 @@ function MatrixAlgebraKit.default_algorithm(
3434
end
3535

3636
function similar_output(
37-
::typeof(svd_compact!),
38-
A,
39-
s_axis::AbstractUnitRange,
40-
alg::MatrixAlgebraKit.AbstractAlgorithm,
37+
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
4138
)
42-
U = similar(A, axes(A, 1), s_axis)
39+
U = similar(A, axes(A, 1), S_axes[1])
4340
T = real(eltype(A))
4441
# TODO: this should be replaced with a more general similar function that can handle setting
4542
# the blocktype and element type - something like S = similar(A, BlockType(...))
46-
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
47-
Vt = similar(A, s_axis, axes(A, 2))
43+
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, S_axes)
44+
Vt = similar(A, S_axes[2], axes(A, 2))
4845
return U, S, Vt
4946
end
5047

@@ -56,27 +53,34 @@ function MatrixAlgebraKit.initialize_output(
5653

5754
brows = eachblockaxis(axes(A, 1))
5855
bcols = eachblockaxis(axes(A, 2))
59-
s_axes = similar(brows, bmn)
56+
u_axes = similar(brows, bmn)
57+
v_axes = similar(brows, bmn)
6058

6159
# fill in values for blocks that are present
6260
bIs = collect(eachblockstoredindex(A))
6361
browIs = Int.(first.(Tuple.(bIs)))
6462
bcolIs = Int.(last.(Tuple.(bIs)))
6563
for bI in eachblockstoredindex(A)
6664
row, col = Int.(Tuple(bI))
67-
s_axes[col] = argmin(length, (brows[row], bcols[col]))
65+
len = minimum(length, (brows[row], bcols[col]))
66+
u_axes[col] = brows[row][Base.OneTo(len)]
67+
v_axes[col] = bcols[col][Base.OneTo(len)]
6868
end
6969

7070
# fill in values for blocks that aren't present, pairing them in order of occurence
7171
# this is a convention, which at least gives the expected results for blockdiagonal
7272
emptyrows = setdiff(1:bm, browIs)
7373
emptycols = setdiff(1:bn, bcolIs)
7474
for (row, col) in zip(emptyrows, emptycols)
75-
s_axes[col] = argmin(length, (brows[row], bcols[col]))
75+
len = minimum(length, (brows[row], bcols[col]))
76+
u_axes[col] = brows[row][Base.OneTo(len)]
77+
v_axes[col] = bcols[col][Base.OneTo(len)]
7678
end
7779

78-
s_axis = mortar_axis(s_axes)
79-
U, S, Vt = similar_output(svd_compact!, A, s_axis, alg)
80+
u_axis = mortar_axis(u_axes)
81+
v_axis = mortar_axis(v_axes)
82+
S_axes = (u_axis, v_axis)
83+
U, S, Vt = similar_output(svd_compact!, A, S_axes, alg)
8084

8185
# allocate output
8286
for bI in eachblockstoredindex(A)
@@ -96,12 +100,12 @@ function MatrixAlgebraKit.initialize_output(
96100
end
97101

98102
function similar_output(
99-
::typeof(svd_full!), A, s_axis::AbstractUnitRange, alg::MatrixAlgebraKit.AbstractAlgorithm
103+
::typeof(svd_full!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
100104
)
101-
U = similar(A, axes(A, 1), s_axis)
105+
U = similar(A, axes(A, 1), S_axes[1])
102106
T = real(eltype(A))
103-
S = similar(A, T, (s_axis, axes(A, 2)))
104-
Vt = similar(A, axes(A, 2), axes(A, 2))
107+
S = similar(A, T, S_axes)
108+
Vt = similar(A, S_axes[2], axes(A, 2))
105109
return U, S, Vt
106110
end
107111

@@ -111,30 +115,31 @@ function MatrixAlgebraKit.initialize_output(
111115
bm, bn = blocksize(A)
112116

113117
brows = eachblockaxis(axes(A, 1))
114-
s_axes = similar(brows)
118+
u_axes = similar(brows)
115119

116120
# fill in values for blocks that are present
117121
bIs = collect(eachblockstoredindex(A))
118122
browIs = Int.(first.(Tuple.(bIs)))
119123
bcolIs = Int.(last.(Tuple.(bIs)))
120124
for bI in eachblockstoredindex(A)
121125
row, col = Int.(Tuple(bI))
122-
s_axes[col] = brows[row]
126+
u_axes[col] = brows[row]
123127
end
124128

125129
# fill in values for blocks that aren't present, pairing them in order of occurence
126130
# this is a convention, which at least gives the expected results for blockdiagonal
127131
emptyrows = setdiff(1:bm, browIs)
128132
emptycols = setdiff(1:bn, bcolIs)
129133
for (row, col) in zip(emptyrows, emptycols)
130-
s_axes[col] = brows[row]
134+
u_axes[col] = brows[row]
131135
end
132136
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
133-
s_axes[bn + i] = brows[emptyrows[k]]
137+
u_axes[bn + i] = brows[emptyrows[k]]
134138
end
135139

136-
s_axis = mortar_axis(s_axes)
137-
U, S, Vt = similar_output(svd_full!, A, s_axis, alg)
140+
u_axis = mortar_axis(u_axes)
141+
S_axes = (u_axis, axes(A, 2))
142+
U, S, Vt = similar_output(svd_full!, A, S_axes, alg)
138143

139144
# allocate output
140145
for bI in eachblockstoredindex(A)

0 commit comments

Comments
 (0)