Skip to content

Commit 4326cdb

Browse files
committed
Add more matrixalgebra methods
1 parent 4bcb517 commit 4326cdb

File tree

3 files changed

+187
-25
lines changed

3 files changed

+187
-25
lines changed

src/TensorKit.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,15 @@ using Base: @boundscheck, @propagate_inbounds, @constprop,
117117
SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype
118118
using Base.Iterators: product, filter
119119

120-
using LinearAlgebra: LinearAlgebra
120+
using LinearAlgebra: LinearAlgebra, BlasFloat
121121
using LinearAlgebra: norm, dot, normalize, normalize!, tr,
122122
axpy!, axpby!, lmul!, rmul!, mul!, ldiv!, rdiv!,
123123
adjoint, adjoint!, transpose, transpose!,
124124
lu, pinv, sylvester,
125125
eigen, eigen!, svd, svd!,
126126
isposdef, isposdef!, ishermitian, rank, cond,
127127
Diagonal, Hermitian
128+
using MatrixAlgebraKit
128129

129130
using SparseArrays: SparseMatrixCSC, sparse, nzrange, rowvals, nonzeros
130131

src/tensors/blockiterator.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434
"""
3535
function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler=nothing)
3636
foreach(blocks(t)) do (c, b)
37-
return f(c, (b, block.(ts, c)...))
37+
return f(c, (b, map(Base.Fix2(block, c), ts)...))
3838
end
3939
return nothing
4040
end

src/tensors/matrixalgebrakit.jl

+184-23
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,210 @@
1-
function MAK.copy_input(::typeof(MAK.eig_full), t::AbstractTensorMap)
2-
return copy_oftype(t, factorisation_scalartype(MAK.eig_full!, t))
1+
# Generic
2+
# -------
3+
for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, :svd_full,
4+
:svd_compact, :svd_vals, :svd_trunc)
5+
@eval function MatrixAlgebraKit.copy_input(::typeof($f),
6+
t::AbstractTensorMap{<:BlasFloat})
7+
T = factorisation_scalartype($f, t)
8+
return copy_oftype(t, T)
9+
end
10+
end
11+
12+
# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
13+
# T = scalartype(t)
14+
# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
15+
# end
16+
17+
# Singular value decomposition
18+
# ----------------------------
19+
function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ))
20+
V_cod = fuse(codomain(t))
21+
V_dom = fuse(domain(t))
22+
23+
(U isa AbstractTensorMap &&
24+
scalartype(U) == scalartype(t) &&
25+
space(U) == (codomain(t) V_cod)) ||
26+
throw(ArgumentError("`svd_full!` requires unitary tensor U with same `scalartype`"))
27+
(S isa AbstractTensorMap &&
28+
scalartype(S) == real(scalartype(t)) &&
29+
space(S) == (V_cod V_dom)) ||
30+
throw(ArgumentError("`svd_full!` requires rectangular tensor S with real `scalartype`"))
31+
(Vᴴ isa AbstractTensorMap &&
32+
scalartype(Vᴴ) == scalartype(t) &&
33+
space(Vᴴ) == (V_dom domain(t))) ||
34+
throw(ArgumentError("`svd_full!` requires unitary tensor Vᴴ with same `scalartype`"))
35+
36+
return nothing
37+
end
38+
39+
function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorMap,
40+
(U, S, Vᴴ))
41+
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
42+
43+
(U isa AbstractTensorMap &&
44+
scalartype(U) == scalartype(t) &&
45+
space(U) == (codomain(t) V_cod)) ||
46+
throw(ArgumentError("`svd_compact!` requires isometric tensor U with same `scalartype`"))
47+
(S isa DiagonalTensorMap &&
48+
scalartype(S) == real(scalartype(t)) &&
49+
space(S) == (V_cod V_dom)) ||
50+
throw(ArgumentError("`svd_compact!` requires diagonal tensor S with real `scalartype`"))
51+
(Vᴴ isa AbstractTensorMap &&
52+
scalartype(Vᴴ) == scalartype(t) &&
53+
space(Vᴴ) == (V_dom domain(t))) ||
54+
throw(ArgumentError("`svd_compact!` requires isometric tensor Vᴴ with same `scalartype`"))
55+
56+
return nothing
57+
end
58+
59+
# TODO: svd_vals
60+
61+
function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTensorMap,
62+
::MatrixAlgebraKit.AbstractAlgorithm)
63+
V_cod = fuse(codomain(t))
64+
V_dom = fuse(domain(t))
65+
U = similar(t, domain(t) V_cod)
66+
S = similar(t, real(scalartype(t)), V_cod V_dom)
67+
Vᴴ = similar(t, domain(t) V_dom)
68+
return U, S, Vᴴ
69+
end
70+
71+
function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap,
72+
::MatrixAlgebraKit.AbstractAlgorithm)
73+
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
74+
U = similar(t, domain(t) V_cod)
75+
S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod V_dom)
76+
Vᴴ = similar(t, domain(t) V_dom)
77+
return U, S, Vᴴ
78+
end
79+
80+
# TODO: svd_vals
81+
82+
function MatrixAlgebraKit.svd_full!(t::AbstractTensorMap, (U, S, Vᴴ),
83+
alg::BlockAlgorithm)
84+
MatrixAlgebraKit.check_input(svd_full!, t, (U, S, Vᴴ))
85+
86+
foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ)
87+
if isempty(b) # TODO: remove once MatrixAlgebraKit supports empty matrices
88+
one!(length(u) > 0 ? u : vᴴ)
89+
zerovector!(s)
90+
else
91+
u′, s′, vᴴ′ = MatrixAlgebraKit.svd_full!(b, (u, s, vᴴ), alg.alg)
92+
# deal with the case where the output is not the same as the input
93+
u === u′ || copyto!(u, u′)
94+
s === s′ || copyto!(s, s′)
95+
vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′)
96+
end
97+
return nothing
98+
end
99+
100+
return U, S, Vᴴ
101+
end
102+
103+
function MatrixAlgebraKit.svd_compact!(t::AbstractTensorMap, (U, S, Vᴴ),
104+
alg::BlockAlgorithm)
105+
MatrixAlgebraKit.check_input(svd_compact!, t, (U, S, Vᴴ))
106+
107+
foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ)
108+
u′, s′, vᴴ′ = svd_compact!(b, (u, s, vᴴ), alg.alg)
109+
# deal with the case where the output is not the same as the input
110+
u === u′ || copyto!(u, u′)
111+
s === s′ || copyto!(s, s′)
112+
vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′)
113+
return nothing
114+
end
115+
116+
return U, S, Vᴴ
117+
end
118+
119+
function MatrixAlgebraKit.default_svd_algorithm(t::AbstractTensorMap{<:BlasFloat};
120+
scheduler=default_blockscheduler(t),
121+
kwargs...)
122+
return BlockAlgorithm(LAPACK_DivideAndConquer(; kwargs...), scheduler)
3123
end
4124

5-
function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
6-
T = complex(scalartype(t))
7-
return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
125+
# Eigenvalue decomposition
126+
# ------------------------
127+
function MatrixAlgebraKit.check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V))
128+
domain(t) == codomain(t) ||
129+
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
130+
131+
V_D = fuse(domain(t))
132+
133+
(D isa DiagonalTensorMap &&
134+
scalartype(D) == real(scalartype(t)) &&
135+
V_D == space(D, 1)) ||
136+
throw(ArgumentError("`eigh_full!` requires diagonal tensor D with isomorphic domain and real `scalartype`"))
137+
138+
V isa AbstractTensorMap &&
139+
scalartype(V) == scalartype(t) &&
140+
space(V) == (codomain(t) V_D) ||
141+
throw(ArgumentError("`eigh_full!` requires square tensor V with isomorphic domain and equal `scalartype`"))
142+
143+
return nothing
8144
end
9145

10-
function MAK.check_input(::typeof(MAK.eig_full!), t::AbstractTensorMap, (D, V))
146+
function MatrixAlgebraKit.check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V))
11147
domain(t) == codomain(t) ||
12148
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
13149
Tc = complex(scalartype(t))
150+
V_D = fuse(domain(t))
14151

15152
(D isa DiagonalTensorMap &&
16153
scalartype(D) == Tc &&
17-
fuse(domain(t)) == space(D, 1)) ||
154+
V_D == space(D, 1)) ||
18155
throw(ArgumentError("`eig_full!` requires diagonal tensor D with isomorphic domain and complex `scalartype`"))
19156

20157
V isa AbstractTensorMap &&
21158
scalartype(V) == Tc &&
22-
space(V) == (codomain(t) codomain(D)) ||
159+
space(V) == (codomain(t) V_D) ||
23160
throw(ArgumentError("`eig_full!` requires square tensor V with isomorphic domain and complex `scalartype`"))
24161

25162
return nothing
26163
end
27164

28-
function MAK.initialize_output(::typeof(MAK.eig_full!), t::AbstractTensorMap,
29-
::MAK.LAPACK_EigAlgorithm)
165+
function MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap,
166+
::MatrixAlgebraKit.AbstractAlgorithm)
167+
V_D = fuse(domain(t))
168+
T = real(scalartype(t))
169+
D = DiagonalTensorMap{T}(undef, V_D)
170+
V = similar(t, codomain(t) V_D)
171+
return D, V
172+
end
173+
174+
function MatrixAlgebraKit.initialize_output(::typeof(eig_full!), t::AbstractTensorMap,
175+
::MatrixAlgebraKit.AbstractAlgorithm)
176+
V_D = fuse(domain(t))
30177
Tc = complex(scalartype(t))
31-
V_diag = fuse(domain(t))
32-
return DiagonalTensorMap{Tc}(undef, V_diag), similar(t, Tc, domain(t) V_diag)
178+
D = DiagonalTensorMap{Tc}(undef, V_D)
179+
V = similar(t, Tc, codomain(t) V_D)
180+
return D, V
33181
end
34182

35-
function MAK.eig_full!(t::AbstractTensorMap, (D, V), alg::MAK.LAPACK_EigAlgorithm)
36-
MAK.check_input(MAK.eig_full!, t, (D, V))
37-
foreachblock(t, D, V) do (_, (b, d, v))
38-
d′, v′ = MAK.eig_full!(b, (d, v), alg)
39-
# deal with the case where the output is not the same as the input
40-
d === d′ || copyto!(d, d′)
41-
v === v′ || copyto!(v, v′)
42-
return nothing
183+
for f in (:eigh_full!, :eig_full!)
184+
@eval function MatrixAlgebraKit.$f(t::AbstractTensorMap, (D, V),
185+
alg::BlockAlgorithm)
186+
MatrixAlgebraKit.check_input($f, t, (D, V))
187+
188+
foreachblock(t, D, V; alg.scheduler) do _, (b, d, v)
189+
d′, v′ = $f(b, (d, v), alg.alg)
190+
# deal with the case where the output is not the same as the input
191+
d === d′ || copyto!(d, d′)
192+
v === v′ || copyto!(v, v′)
193+
return nothing
194+
end
195+
196+
return D, V
43197
end
44-
return D, V
45198
end
46199

47-
function MAK.default_eig_algorithm(::TensorMap{<:LinearAlgebra.BlasFloat}; kwargs...)
48-
return MAK.LAPACK_Expert(; kwargs...)
200+
function MatrixAlgebraKit.default_eig_algorithm(t::AbstractTensorMap{<:BlasFloat};
201+
scheduler=default_blockscheduler(t),
202+
kwargs...)
203+
return BlockAlgorithm(LAPACK_Expert(; kwargs...), scheduler)
204+
end
205+
function MatrixAlgebraKit.default_eigh_algorithm(t::AbstractTensorMap{<:BlasFloat};
206+
scheduler=default_blockscheduler(t),
207+
kwargs...)
208+
return BlockAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...),
209+
scheduler)
49210
end

0 commit comments

Comments
 (0)