-
Notifications
You must be signed in to change notification settings - Fork 2
Add truncation functionality for SVD #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
b444a44
Add initial truncation implementation
lkdvos b2f448c
Add dedicated truncation type
lkdvos 792fbd3
format docstring
lkdvos 92b7060
fix imports
lkdvos b1ae30f
pass correct truncation strategy
lkdvos 1459640
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2ea41a3
avoid `enumerate`
lkdvos 5934671
Merge branch 'truncation' of https://github.com/ITensor/BlockSparseAr…
lkdvos 3718571
Refactor truncation slightly
lkdvos dfa1c46
Bump version
lkdvos 0c9f42d
Create nicer objects for diagview
lkdvos 945c660
relax error criterion
lkdvos 17ccf8e
Add tests
lkdvos File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "BlockSparseArrays" | ||
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" | ||
authors = ["ITensor developers <[email protected]> and contributors"] | ||
version = "0.5.0" | ||
version = "0.5.1" | ||
|
||
[deps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc! | ||
|
||
function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T} | ||
D = BlockSparseVector{T}(undef, axes(A, 1)) | ||
for I in eachblockstoredindex(A) | ||
if ==(Int.(Tuple(I))...) | ||
D[Tuple(I)[1]] = diagview(A[I]) | ||
end | ||
end | ||
return D | ||
end | ||
|
||
""" | ||
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy) | ||
|
||
A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block | ||
basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted | ||
block-diagonal matrix. | ||
""" | ||
struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy | ||
strategy::T | ||
end | ||
|
||
const TBlockUSVᴴ = Tuple{ | ||
<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix | ||
} | ||
|
||
function MatrixAlgebraKit.truncate!( | ||
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy | ||
) | ||
# TODO assert blockdiagonal | ||
return MatrixAlgebraKit.truncate!( | ||
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy) | ||
) | ||
end | ||
|
||
# cannot use regular slicing here: I want to slice without altering blockstructure | ||
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map | ||
function MatrixAlgebraKit.findtruncated( | ||
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy | ||
) | ||
ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy) | ||
indexmask = falses(length(values)) | ||
indexmask[ind] .= true | ||
return indexmask | ||
end | ||
|
||
function MatrixAlgebraKit.truncate!( | ||
::typeof(svd_trunc!), | ||
(U, S, Vᴴ)::TBlockUSVᴴ, | ||
strategy::BlockPermutedDiagonalTruncationStrategy, | ||
) | ||
indexmask = MatrixAlgebraKit.findtruncated(diagview(S), strategy) | ||
|
||
# first determine the block structure of the output to avoid having assumptions on the | ||
# data structures | ||
ax = axes(S, 1) | ||
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask)) | ||
Slengths = filter!(>(0), map(counter, blocks(ax))) | ||
Sax = blockedrange(Slengths) | ||
Ũ = similar(U, axes(U, 1), Sax) | ||
S̃ = similar(S, Sax, Sax) | ||
Ṽᴴ = similar(Vᴴ, Sax, axes(Vᴴ, 2)) | ||
|
||
# then loop over the blocks and assign the data | ||
# TODO: figure out if we can presort and loop over the blocks - | ||
# for now this has issues with missing blocks | ||
bI_Us = collect(eachblockstoredindex(U)) | ||
bI_Ss = collect(eachblockstoredindex(S)) | ||
bI_Vᴴs = collect(eachblockstoredindex(Vᴴ)) | ||
|
||
I′ = 0 # number of skipped blocks that got fully truncated | ||
for I in 1:blocksize(ax, 1) | ||
b = ax[Block(I)] | ||
mask = indexmask[b] | ||
|
||
if !any(mask) | ||
I′ += 1 | ||
continue | ||
end | ||
|
||
bU_id = @something findfirst(x -> last(Tuple(x)) == Block(I), bI_Us) error( | ||
"No U-block found for $I" | ||
) | ||
bU = Tuple(bI_Us[bU_id]) | ||
Ũ[bU[1], bU[2] - Block(I′)] = view(U, bU...)[:, mask] | ||
|
||
bVᴴ_id = @something findfirst(x -> first(Tuple(x)) == Block(I), bI_Vᴴs) error( | ||
"No Vᴴ-block found for $I" | ||
) | ||
bVᴴ = Tuple(bI_Vᴴs[bVᴴ_id]) | ||
Ṽᴴ[bVᴴ[1] - Block(I′), bVᴴ[2]] = view(Vᴴ, bVᴴ...)[mask, :] | ||
|
||
bS_id = findfirst(x -> last(Tuple(x)) == Block(I), bI_Ss) | ||
if !isnothing(bS_id) | ||
bS = Tuple(bI_Ss[bS_id]) | ||
S̃[(bS .- Block(I′))...] = Diagonal(diagview(view(S, bS...))[mask]) | ||
end | ||
end | ||
|
||
return Ũ, S̃, Ṽᴴ | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.