Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] MatrixAlgebraKit decompositions #230

Draft
wants to merge 14 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -6,8 +6,11 @@ version = "0.14.5"
[deps]
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
@@ -31,8 +34,11 @@ Combinatorics = "1"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.1.1"
OhMyThreads = "0.8.0"
PackageExtensionCompat = "1"
Random = "1"
ScopedValues = "1.3.0"
SparseArrays = "1"
Strided = "2"
TensorKitSectors = "0.1"
9 changes: 8 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
@@ -100,7 +100,11 @@ using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon
using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend
const TO = TensorOperations

using MatrixAlgebraKit: MatrixAlgebraKit as MAK

using LRUCache
using OhMyThreads
using ScopedValues

using TensorKitSectors
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗
@@ -113,14 +117,15 @@ using Base: @boundscheck, @propagate_inbounds, @constprop,
SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype
using Base.Iterators: product, filter

using LinearAlgebra: LinearAlgebra
using LinearAlgebra: LinearAlgebra, BlasFloat
using LinearAlgebra: norm, dot, normalize, normalize!, tr,
axpy!, axpby!, lmul!, rmul!, mul!, ldiv!, rdiv!,
adjoint, adjoint!, transpose, transpose!,
lu, pinv, sylvester,
eigen, eigen!, svd, svd!,
isposdef, isposdef!, ishermitian, rank, cond,
Diagonal, Hermitian
using MatrixAlgebraKit

using SparseArrays: SparseMatrixCSC, sparse, nzrange, rowvals, nonzeros

@@ -184,6 +189,7 @@ include("spaces/vectorspaces.jl")
#-------------------------------------
# general definitions
include("tensors/abstracttensor.jl")
include("tensors/backends.jl")
include("tensors/blockiterator.jl")
include("tensors/tensor.jl")
include("tensors/adjoint.jl")
@@ -194,6 +200,7 @@ include("tensors/treetransformers.jl")
include("tensors/indexmanipulations.jl")
include("tensors/diagonal.jl")
include("tensors/truncation.jl")
include("tensors/matrixalgebrakit.jl")
include("tensors/factorizations.jl")
include("tensors/braidingtensor.jl")

43 changes: 43 additions & 0 deletions src/tensors/backends.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Scheduler implementation
# ------------------------
function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs)
Threads.nthreads() > 1 ? SerialScheduler() : DynamicScheduler()

Check warning on line 5 in src/tensors/backends.jl

Codecov / codecov/patch

src/tensors/backends.jl#L3-L5

Added lines #L3 - L5 were not covered by tests
else
OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)

Check warning on line 7 in src/tensors/backends.jl

Codecov / codecov/patch

src/tensors/backends.jl#L7

Added line #L7 was not covered by tests
end
end

"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

The default scheduler used when looping over different blocks in the matrix representation of a
tensor.
For controlling this value, see also [`set_blockscheduler`](@ref) and [`with_blockscheduler`](@ref).
"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

"""
with_blockscheduler(f, [scheduler]; kwargs...)

Run `f` in a scope where the `blockscheduler` is determined by `scheduler' and `kwargs...`.
"""
@inline function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();

Check warning on line 25 in src/tensors/backends.jl

Codecov / codecov/patch

src/tensors/backends.jl#L25

Added line #L25 was not covered by tests
kwargs...)
@with blockscheduler => select_scheduler(scheduler; kwargs...) f()

Check warning on line 27 in src/tensors/backends.jl

Codecov / codecov/patch

src/tensors/backends.jl#L27

Added line #L27 was not covered by tests
end

# TODO: disable for trivial symmetry or small tensors?
default_blockscheduler(t::AbstractTensorMap) = blockscheduler[]

# MatrixAlgebraKit
# ----------------
"""
BlockAlgorithm{A,S}(alg, scheduler)

Generic wrapper for implementing block-wise algorithms.
"""
struct BlockAlgorithm{A,S} <: MatrixAlgebraKit.AbstractAlgorithm
alg::A
scheduler::S
end
25 changes: 25 additions & 0 deletions src/tensors/blockiterator.jl
Original file line number Diff line number Diff line change
@@ -13,3 +13,28 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
Base.length(iter::BlockIterator) = length(iter.structure)
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)

# TODO: fast-path when structures are the same?
# TODO: do we want f(c, bs...) or f(c, bs)?
# TODO: implement scheduler
# TODO: do we prefer `blocks(t, ts...)` instead or as well?
"""
foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; [scheduler])

Apply `f` to each block of `t` and the corresponding blocks of `ts`.
Optionally, `scheduler` can be used to parallelize the computation.
This function is equivalent to the following loop:

```julia
for (c, b) in blocks(t)
bs = (b, block.(ts, c)...)
f(c, bs)
end
```
"""
function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler=nothing)
foreach(blocks(t)) do (c, b)
return f(c, (b, map(Base.Fix2(block, c), ts)...))
end
return nothing
end
10 changes: 10 additions & 0 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
@@ -37,6 +37,16 @@

Construct a `DiagonalTensorMap` with uninitialized data.
"""
function DiagonalTensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T}
(numin(V) == numout(V) == 1 && domain(V) == codomain(V)) ||

Check warning on line 41 in src/tensors/diagonal.jl

Codecov / codecov/patch

src/tensors/diagonal.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
throw(SpaceMismatch("DiagonalTensorMap requires a space with equal domain and codomain and 2 indices"))
return DiagonalTensorMap{T}(undef, domain(V))

Check warning on line 43 in src/tensors/diagonal.jl

Codecov / codecov/patch

src/tensors/diagonal.jl#L43

Added line #L43 was not covered by tests
end
function DiagonalTensorMap{T}(::UndefInitializer, V::ProductSpace) where {T}
length(V) == 1 ||

Check warning on line 46 in src/tensors/diagonal.jl

Codecov / codecov/patch

src/tensors/diagonal.jl#L45-L46

Added lines #L45 - L46 were not covered by tests
throw(DimensionMismatch("length(V) = $(length(V)) is not compatible with the space $V"))
return DiagonalTensorMap{T}(undef, only(V))

Check warning on line 48 in src/tensors/diagonal.jl

Codecov / codecov/patch

src/tensors/diagonal.jl#L48

Added line #L48 was not covered by tests
end
function DiagonalTensorMap{T}(::UndefInitializer, V::S) where {T,S<:IndexSpace}
return DiagonalTensorMap{T,S,Vector{T}}(undef, V)
end
98 changes: 54 additions & 44 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
@@ -531,6 +531,13 @@ function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
end

# implementation dispatches on algorithm
function _tsvd!(t::TensorMap{<:BlasFloat}, alg::Union{SVD,SDD},
::NoTruncation, p::Real=2)
scheduler = default_blockscheduler(t)
svd_alg = alg isa SDD ? LAPACK_DivideAndConquer() : LAPACK_QRIteration()
return MatrixAlgebraKit.svd_compact!(t; alg=BlockAlgorithm(svd_alg, scheduler))...,
zero(real(scalartype(t)))
end
function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD},
trunc::TruncationScheme, p::Real=2)
# early return
@@ -617,50 +624,53 @@ function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwarg
for (c, b) in blocks(t))
end

function eigh!(t::TensorMap{<:RealOrComplexFloat})
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))

T = scalartype(t)
I = sectortype(t)
S = spacetype(t)
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
W = S(dims)

Tr = real(T)
A = similarstoragetype(t, Tr)
D = DiagonalTensorMap{Tr,S,A}(undef, W)
V = similar(t, domain(t) ← W)
for (c, b) in blocks(t)
values, vectors = MatrixAlgebra.eigh!(b)
copy!(block(D, c), Diagonal(values))
copy!(block(V, c), vectors)
end
return D, V
end

function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))

T = scalartype(t)
I = sectortype(t)
S = spacetype(t)
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
W = S(dims)

Tc = complex(T)
A = similarstoragetype(t, Tc)
D = DiagonalTensorMap{Tc,S,A}(undef, W)
V = similar(t, Tc, domain(t) ← W)
for (c, b) in blocks(t)
values, vectors = MatrixAlgebra.eig!(b; kwargs...)
copy!(block(D, c), Diagonal(values))
copy!(block(V, c), vectors)
end
return D, V
end
eigh!(t::TensorMap{<:RealOrComplexFloat}) = eigh_full!(t)
eig!(t::TensorMap{<:RealOrComplexFloat}) = eig_full!(t)

# function eigh!(t::TensorMap{<:RealOrComplexFloat})
# InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)
# domain(t) == codomain(t) ||
# throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))

# T = scalartype(t)
# I = sectortype(t)
# S = spacetype(t)
# dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
# W = S(dims)

# Tr = real(T)
# A = similarstoragetype(t, Tr)
# D = DiagonalTensorMap{Tr,S,A}(undef, W)
# V = similar(t, domain(t) ← W)
# for (c, b) in blocks(t)
# values, vectors = MatrixAlgebra.eigh!(b)
# copy!(block(D, c), Diagonal(values))
# copy!(block(V, c), vectors)
# end
# return D, V
# end

# function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
# domain(t) == codomain(t) ||
# throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))

# T = scalartype(t)
# I = sectortype(t)
# S = spacetype(t)
# dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
# W = S(dims)

# Tc = complex(T)
# A = similarstoragetype(t, Tc)
# D = DiagonalTensorMap{Tc,S,A}(undef, W)
# V = similar(t, Tc, domain(t) ← W)
# for (c, b) in blocks(t)
# values, vectors = MatrixAlgebra.eig!(b; kwargs...)
# copy!(block(D, c), Diagonal(values))
# copy!(block(V, c), vectors)
# end
# return D, V
# end

#--------------------------------------------------#
# Checks for hermiticity and positive definiteness #
Loading
Loading