Skip to content

Commit bb62929

Browse files
authored
Add support for block sparse QR decomposition (#117)
1 parent 2120b7a commit bb62929

File tree

4 files changed

+251
-2
lines changed

4 files changed

+251
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.1"
4+
version = "0.6.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4646
# factorizations
4747
include("factorizations/svd.jl")
4848
include("factorizations/truncation.jl")
49+
include("factorizations/qr.jl")
4950

5051
end

src/factorizations/qr.jl

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
using MatrixAlgebraKit: MatrixAlgebraKit, qr_compact!, qr_full!
2+
3+
# TODO: this is a hardcoded for now to get around this function not being defined in the
4+
# type domain
5+
function default_blocksparse_qr_algorithm(A::AbstractMatrix; kwargs...)
6+
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
7+
error("unsupported type: $(blocktype(A))")
8+
alg = MatrixAlgebraKit.LAPACK_HouseholderQR(; kwargs...)
9+
return BlockPermutedDiagonalAlgorithm(alg)
10+
end
11+
function MatrixAlgebraKit.default_algorithm(
12+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix; kwargs...
13+
)
14+
return default_blocksparse_qr_algorithm(A; kwargs...)
15+
end
16+
function MatrixAlgebraKit.default_algorithm(
17+
::typeof(qr_full!), A::AbstractBlockSparseMatrix; kwargs...
18+
)
19+
return default_blocksparse_qr_algorithm(A; kwargs...)
20+
end
21+
22+
function similar_output(
23+
::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
24+
)
25+
Q = similar(A, axes(A, 1), R_axis)
26+
R = similar(A, R_axis, axes(A, 2))
27+
return Q, R
28+
end
29+
30+
function similar_output(
31+
::typeof(qr_full!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
32+
)
33+
Q = similar(A, axes(A, 1), R_axis)
34+
R = similar(A, R_axis, axes(A, 2))
35+
return Q, R
36+
end
37+
38+
function MatrixAlgebraKit.initialize_output(
39+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
40+
)
41+
bm, bn = blocksize(A)
42+
bmn = min(bm, bn)
43+
44+
brows = eachblockaxis(axes(A, 1))
45+
bcols = eachblockaxis(axes(A, 2))
46+
r_axes = similar(brows, bmn)
47+
48+
# fill in values for blocks that are present
49+
bIs = collect(eachblockstoredindex(A))
50+
browIs = Int.(first.(Tuple.(bIs)))
51+
bcolIs = Int.(last.(Tuple.(bIs)))
52+
for bI in eachblockstoredindex(A)
53+
row, col = Int.(Tuple(bI))
54+
len = minimum(length, (brows[row], bcols[col]))
55+
r_axes[col] = brows[row][Base.OneTo(len)]
56+
end
57+
58+
# fill in values for blocks that aren't present, pairing them in order of occurence
59+
# this is a convention, which at least gives the expected results for blockdiagonal
60+
emptyrows = setdiff(1:bm, browIs)
61+
emptycols = setdiff(1:bn, bcolIs)
62+
for (row, col) in zip(emptyrows, emptycols)
63+
len = minimum(length, (brows[row], bcols[col]))
64+
r_axes[col] = brows[row][Base.OneTo(len)]
65+
end
66+
67+
r_axis = mortar_axis(r_axes)
68+
Q, R = similar_output(qr_compact!, A, r_axis, alg)
69+
70+
# allocate output
71+
for bI in eachblockstoredindex(A)
72+
brow, bcol = Tuple(bI)
73+
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
74+
qr_compact!, @view!(A[bI]), alg.alg
75+
)
76+
end
77+
78+
# allocate output for blocks that aren't present -- do we also fill identities here?
79+
for (row, col) in zip(emptyrows, emptycols)
80+
@view!(Q[Block(row, col)])
81+
end
82+
83+
return Q, R
84+
end
85+
86+
function MatrixAlgebraKit.initialize_output(
87+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
88+
)
89+
bm, bn = blocksize(A)
90+
91+
brows = eachblockaxis(axes(A, 1))
92+
r_axes = copy(brows)
93+
94+
# fill in values for blocks that are present
95+
bIs = collect(eachblockstoredindex(A))
96+
browIs = Int.(first.(Tuple.(bIs)))
97+
bcolIs = Int.(last.(Tuple.(bIs)))
98+
for bI in eachblockstoredindex(A)
99+
row, col = Int.(Tuple(bI))
100+
r_axes[col] = brows[row]
101+
end
102+
103+
# fill in values for blocks that aren't present, pairing them in order of occurence
104+
# this is a convention, which at least gives the expected results for blockdiagonal
105+
emptyrows = setdiff(1:bm, browIs)
106+
emptycols = setdiff(1:bn, bcolIs)
107+
for (row, col) in zip(emptyrows, emptycols)
108+
r_axes[col] = brows[row]
109+
end
110+
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
111+
r_axes[bn + i] = brows[emptyrows[k]]
112+
end
113+
114+
r_axis = mortar_axis(r_axes)
115+
Q, R = similar_output(qr_full!, A, r_axis, alg)
116+
117+
# allocate output
118+
for bI in eachblockstoredindex(A)
119+
brow, bcol = Tuple(bI)
120+
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
121+
qr_full!, @view!(A[bI]), alg.alg
122+
)
123+
end
124+
125+
# allocate output for blocks that aren't present -- do we also fill identities here?
126+
for (row, col) in zip(emptyrows, emptycols)
127+
@view!(Q[Block(row, col)])
128+
end
129+
# also handle extra rows/cols
130+
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
131+
@view!(Q[Block(emptyrows[k], bn + i)])
132+
end
133+
134+
return Q, R
135+
end
136+
137+
function MatrixAlgebraKit.check_input(
138+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, QR
139+
)
140+
Q, R = QR
141+
@assert isa(Q, AbstractBlockSparseMatrix) && isa(R, AbstractBlockSparseMatrix)
142+
@assert eltype(A) == eltype(Q) == eltype(R)
143+
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
144+
@assert axes(Q, 2) == axes(R, 1)
145+
146+
return nothing
147+
end
148+
149+
function MatrixAlgebraKit.check_input(::typeof(qr_full!), A::AbstractBlockSparseMatrix, QR)
150+
Q, R = QR
151+
@assert isa(Q, AbstractBlockSparseMatrix) && isa(R, AbstractBlockSparseMatrix)
152+
@assert eltype(A) == eltype(Q) == eltype(R)
153+
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
154+
@assert axes(Q, 2) == axes(R, 1)
155+
156+
return nothing
157+
end
158+
159+
function MatrixAlgebraKit.qr_compact!(
160+
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
161+
)
162+
MatrixAlgebraKit.check_input(qr_compact!, A, QR)
163+
Q, R = QR
164+
165+
# do decomposition on each block
166+
for bI in eachblockstoredindex(A)
167+
brow, bcol = Tuple(bI)
168+
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
169+
qr′ = qr_compact!(@view!(A[bI]), qr, alg.alg)
170+
@assert qr === qr′ "qr_compact! might not be in-place"
171+
end
172+
173+
# fill in identities for blocks that aren't present
174+
bIs = collect(eachblockstoredindex(A))
175+
browIs = Int.(first.(Tuple.(bIs)))
176+
bcolIs = Int.(last.(Tuple.(bIs)))
177+
emptyrows = setdiff(1:blocksize(A, 1), browIs)
178+
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
179+
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
180+
# Q[Block(row, col)] = LinearAlgebra.I
181+
for (row, col) in zip(emptyrows, emptycols)
182+
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
183+
end
184+
185+
return QR
186+
end
187+
188+
function MatrixAlgebraKit.qr_full!(
189+
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
190+
)
191+
MatrixAlgebraKit.check_input(qr_full!, A, QR)
192+
Q, R = QR
193+
194+
# do decomposition on each block
195+
for bI in eachblockstoredindex(A)
196+
brow, bcol = Tuple(bI)
197+
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
198+
qr′ = qr_full!(@view!(A[bI]), qr, alg.alg)
199+
@assert qr === qr′ "qr_full! might not be in-place"
200+
end
201+
202+
# fill in identities for blocks that aren't present
203+
bIs = collect(eachblockstoredindex(A))
204+
browIs = Int.(first.(Tuple.(bIs)))
205+
bcolIs = Int.(last.(Tuple.(bIs)))
206+
emptyrows = setdiff(1:blocksize(A, 1), browIs)
207+
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
208+
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
209+
# Q[Block(row, col)] = LinearAlgebra.I
210+
for (row, col) in zip(emptyrows, emptycols)
211+
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
212+
end
213+
214+
# also handle extra rows/cols
215+
bn = blocksize(A, 2)
216+
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
217+
copyto!(@view!(Q[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
218+
end
219+
220+
return QR
221+
end

test/test_factorizations.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
3-
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank, trunctol
3+
using MatrixAlgebraKit:
4+
qr_compact, qr_full, svd_compact, svd_full, svd_trunc, truncrank, trunctol
45
using LinearAlgebra: LinearAlgebra
56
using Random: Random
67
using Test: @inferred, @testset, @test
@@ -154,3 +155,29 @@ end
154155
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
155156
end
156157
end
158+
159+
@testset "qr_compact" for T in (Float32, Float64, ComplexF32, ComplexF64)
160+
for i in [1, 2], j in [1, 2], k in [1, 2], l in [1, 2]
161+
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
162+
A[Block(1, 1)] = randn(T, i, k)
163+
A[Block(2, 2)] = randn(T, j, l)
164+
Q, R = qr_compact(A)
165+
@test Matrix(Q'Q) LinearAlgebra.I
166+
@test A Q * R
167+
end
168+
end
169+
170+
@testset "qr_full" for T in (Float32, Float64, ComplexF32, ComplexF64)
171+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
172+
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
173+
A[Block(1, 1)] = randn(T, i, k)
174+
A[Block(2, 2)] = randn(T, j, l)
175+
Q, R = qr_full(A)
176+
Q′, R′ = qr_full(Matrix(A))
177+
@test size(Q) == size(Q′)
178+
@test size(R) == size(R′)
179+
@test Matrix(Q'Q) LinearAlgebra.I
180+
@test Matrix(Q * Q') LinearAlgebra.I
181+
@test A Q * R
182+
end
183+
end

0 commit comments

Comments
 (0)