Skip to content

Add truncation functionality for SVD #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 9, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.0"
version = "0.5.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
1 change: 1 addition & 0 deletions src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")

# factorizations
include("factorizations/svd.jl")
include("factorizations/truncation.jl")

end
3 changes: 2 additions & 1 deletion src/factorizations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using MatrixAlgebraKit: MatrixAlgebraKit, svd_compact!, svd_full!
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)

A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped algorithm on
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted block-diagonal matrix.
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or
a block permuted block-diagonal matrix.
"""
struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
MatrixAlgebraKit.AbstractAlgorithm
Expand Down
102 changes: 102 additions & 0 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc!

function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
D = BlockSparseVector{T}(undef, axes(A, 1))
for I in eachblockstoredindex(A)
if ==(Int.(Tuple(I))...)
D[Tuple(I)[1]] = diagview(A[I])
end
end
return D
end

"""
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)

A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block
basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted
block-diagonal matrix.
"""
struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
strategy::T
end

const TBlockUSVᴴ = Tuple{
<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix
}

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy
)
# TODO assert blockdiagonal
return MatrixAlgebraKit.truncate!(
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
)
end

# cannot use regular slicing here: I want to slice without altering blockstructure
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
function MatrixAlgebraKit.findtruncated(
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
)
ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy)
indexmask = falses(length(values))
indexmask[ind] .= true
return indexmask
end

function MatrixAlgebraKit.truncate!(
::typeof(svd_trunc!),
(U, S, Vᴴ)::TBlockUSVᴴ,
strategy::BlockPermutedDiagonalTruncationStrategy,
)
indexmask = MatrixAlgebraKit.findtruncated(diagview(S), strategy)

# first determine the block structure of the output to avoid having assumptions on the
# data structures
ax = axes(S, 1)
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
Slengths = filter!(>(0), map(counter, blocks(ax)))
Sax = blockedrange(Slengths)
Ũ = similar(U, axes(U, 1), Sax)
S̃ = similar(S, Sax, Sax)
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2))

# then loop over the blocks and assign the data
# TODO: figure out if we can presort and loop over the blocks -
# for now this has issues with missing blocks
bI_Us = collect(eachblockstoredindex(U))
bI_Ss = collect(eachblockstoredindex(S))
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ))

I′ = 0 # number of skipped blocks that got fully truncated
for I in 1:blocksize(ax, 1)
b = ax[Block(I)]
mask = indexmask[b]

if !any(mask)
I′ += 1
continue
end

bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error(
"No U-block found for $I"
)
bU = Tuple(bI_Us[bU_id])
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask]

bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error(
"No Vᴴ-block found for $I"
)
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id])
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :]

bS_id = findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss)
if !isnothing(bS_id)
bS = Tuple(bI_Ss[bS_id])
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask])
end
end

return Ũ, S̃, Ṽᴴ
end
73 changes: 72 additions & 1 deletion test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
using MatrixAlgebraKit: svd_compact, svd_full
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc, truncrank, trunctol
using LinearAlgebra: LinearAlgebra
using Random: Random
using Test: @inferred, @testset, @test
Expand Down Expand Up @@ -83,3 +83,74 @@ end
usv = svd_full(c)
@test test_svd(c, usv; full=true)
end

# svd_trunc!
# ----------

@testset "svd_trunc ($m, $n) BlockSparseMatri{$T}" for ((m, n), T) in test_params
a = BlockSparseArray{T}(undef, m, n)

# test blockdiagonal
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
end

minmn = min(size(a)...)
r = max(1, minmn - 2)
trunc = truncrank(r)

U1, S1, V1ᴴ = svd_trunc(a; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)

atol = minimum(LinearAlgebra.diag(S1)) + 10 * eps(real(T))
trunc = trunctol(atol)

U1, S1, V1ᴴ = svd_trunc(a; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(a); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)

# test permuted blockdiagonal
perm = Random.randperm(length(m))
b = a[Block.(perm), Block.(1:length(n))]
for trunc in (truncrank(r), trunctol(atol))
U1, S1, V1ᴴ = svd_trunc(b; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(b); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)
end

# test permuted blockdiagonal with missing row/col
I_removed = rand(eachblockstoredindex(b))
c = copy(b)
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
for trunc in (truncrank(r), trunctol(atol))
U1, S1, V1ᴴ = svd_trunc(c; trunc)
U2, S2, V2ᴴ = svd_trunc(Matrix(c); trunc)
@test size(U1) == size(U2)
@test size(S1) == size(S2)
@test size(V1ᴴ) == size(V2ᴴ)
@test Matrix(U1 * S1 * V1ᴴ) ≈ U2 * S2 * V2ᴴ

@test (U1' * U1 ≈ LinearAlgebra.I)
@test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I)
end
end
Loading