@@ -21,11 +21,19 @@ function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kw
21
21
return BlockPermutedDiagonalAlgorithm (alg)
22
22
end
23
23
24
- # TODO : this should be replaced with a more general similar function that can handle setting
25
- # the blocktype and element type - something like S = similar(A, BlockType(...))
26
- function _similar_S (A:: AbstractBlockSparseMatrix , s_axis)
24
+ function similar_output (
25
+ :: typeof (svd_compact!),
26
+ A,
27
+ s_axis:: AbstractUnitRange ,
28
+ alg:: MatrixAlgebraKit.AbstractAlgorithm ,
29
+ )
30
+ U = similar (A, axes (A, 1 ), s_axis)
27
31
T = real (eltype (A))
28
- return BlockSparseArray {T,2,Diagonal{T,Vector{T}}} (undef, (s_axis, s_axis))
32
+ # TODO : this should be replaced with a more general similar function that can handle setting
33
+ # the blocktype and element type - something like S = similar(A, BlockType(...))
34
+ S = BlockSparseMatrix {T,Diagonal{T,Vector{T}}} (undef, (s_axis, s_axis))
35
+ Vt = similar (A, s_axis, axes (A, 2 ))
36
+ return U, S, Vt
29
37
end
30
38
31
39
function MatrixAlgebraKit. initialize_output (
@@ -34,33 +42,29 @@ function MatrixAlgebraKit.initialize_output(
34
42
bm, bn = blocksize (A)
35
43
bmn = min (bm, bn)
36
44
37
- brows = blocklengths (axes (A, 1 ))
38
- bcols = blocklengths (axes (A, 2 ))
39
- slengths = Vector {Int} (undef , bmn)
45
+ brows = eachblockaxis (axes (A, 1 ))
46
+ bcols = eachblockaxis (axes (A, 2 ))
47
+ s_axes = similar (brows , bmn)
40
48
41
49
# fill in values for blocks that are present
42
50
bIs = collect (eachblockstoredindex (A))
43
51
browIs = Int .(first .(Tuple .(bIs)))
44
52
bcolIs = Int .(last .(Tuple .(bIs)))
45
53
for bI in eachblockstoredindex (A)
46
54
row, col = Int .(Tuple (bI))
47
- nrows = brows[row]
48
- ncols = bcols[col]
49
- slengths[col] = min (nrows, ncols)
55
+ s_axes[col] = argmin (length, (brows[row], bcols[col]))
50
56
end
51
57
52
58
# fill in values for blocks that aren't present, pairing them in order of occurence
53
59
# this is a convention, which at least gives the expected results for blockdiagonal
54
60
emptyrows = setdiff (1 : bm, browIs)
55
61
emptycols = setdiff (1 : bn, bcolIs)
56
62
for (row, col) in zip (emptyrows, emptycols)
57
- slengths [col] = min ( brows[row], bcols[col])
63
+ s_axes [col] = argmin (length, ( brows[row], bcols[col]) )
58
64
end
59
65
60
- s_axis = blockedrange (slengths)
61
- U = similar (A, axes (A, 1 ), s_axis)
62
- S = _similar_S (A, s_axis)
63
- Vt = similar (A, s_axis, axes (A, 2 ))
66
+ s_axis = mortar_axis (s_axes)
67
+ U, S, Vt = similar_output (svd_compact!, A, s_axis, alg)
64
68
65
69
# allocate output
66
70
for bI in eachblockstoredindex (A)
@@ -79,40 +83,46 @@ function MatrixAlgebraKit.initialize_output(
79
83
return U, S, Vt
80
84
end
81
85
86
+ function similar_output (
87
+ :: typeof (svd_full!), A, s_axis:: AbstractUnitRange , alg:: MatrixAlgebraKit.AbstractAlgorithm
88
+ )
89
+ U = similar (A, axes (A, 1 ), s_axis)
90
+ T = real (eltype (A))
91
+ S = similar (A, T, (s_axis, axes (A, 2 )))
92
+ Vt = similar (A, axes (A, 2 ), axes (A, 2 ))
93
+ return U, S, Vt
94
+ end
95
+
82
96
function MatrixAlgebraKit. initialize_output (
83
97
:: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
84
98
)
85
99
bm, bn = blocksize (A)
86
100
87
- brows = blocklengths (axes (A, 1 ))
88
- slengths = copy (brows)
101
+ brows = eachblockaxis (axes (A, 1 ))
102
+ s_axes = similar (brows)
89
103
90
104
# fill in values for blocks that are present
91
105
bIs = collect (eachblockstoredindex (A))
92
106
browIs = Int .(first .(Tuple .(bIs)))
93
107
bcolIs = Int .(last .(Tuple .(bIs)))
94
108
for bI in eachblockstoredindex (A)
95
109
row, col = Int .(Tuple (bI))
96
- nrows = brows[row]
97
- slengths[col] = nrows
110
+ s_axes[col] = brows[row]
98
111
end
99
112
100
113
# fill in values for blocks that aren't present, pairing them in order of occurence
101
114
# this is a convention, which at least gives the expected results for blockdiagonal
102
115
emptyrows = setdiff (1 : bm, browIs)
103
116
emptycols = setdiff (1 : bn, bcolIs)
104
117
for (row, col) in zip (emptyrows, emptycols)
105
- slengths [col] = brows[row]
118
+ s_axes [col] = brows[row]
106
119
end
107
120
for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
108
- slengths [bn + i] = brows[emptyrows[k]]
121
+ s_axes [bn + i] = brows[emptyrows[k]]
109
122
end
110
123
111
- s_axis = blockedrange (slengths)
112
- U = similar (A, axes (A, 1 ), s_axis)
113
- Tr = real (eltype (A))
114
- S = similar (A, Tr, (s_axis, axes (A, 2 )))
115
- Vt = similar (A, axes (A, 2 ), axes (A, 2 ))
124
+ s_axis = mortar_axis (s_axes)
125
+ U, S, Vt = similar_output (svd_full!, A, s_axis, alg)
116
126
117
127
# allocate output
118
128
for bI in eachblockstoredindex (A)
0 commit comments