|
1 |
| -function MAK.copy_input(::typeof(MAK.eig_full), t::AbstractTensorMap) |
2 |
| - return copy_oftype(t, factorisation_scalartype(MAK.eig_full!, t)) |
| 1 | +# Generic |
| 2 | +# ------- |
| 3 | +for f in (:eig_full, :eig_vals, :eig_trunc, :eigh_full, :eigh_vals, :eigh_trunc, :svd_full, |
| 4 | + :svd_compact, :svd_vals, :svd_trunc) |
| 5 | + @eval function MatrixAlgebraKit.copy_input(::typeof($f), |
| 6 | + t::AbstractTensorMap{<:BlasFloat}) |
| 7 | + T = factorisation_scalartype($f, t) |
| 8 | + return copy_oftype(t, T) |
| 9 | + end |
| 10 | +end |
| 11 | + |
| 12 | +# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) |
| 13 | +# T = scalartype(t) |
| 14 | +# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) |
| 15 | +# end |
| 16 | + |
| 17 | +# Singular value decomposition |
| 18 | +# ---------------------------- |
| 19 | +function MatrixAlgebraKit.check_input(::typeof(svd_full!), t::AbstractTensorMap, (U, S, Vᴴ)) |
| 20 | + V_cod = fuse(codomain(t)) |
| 21 | + V_dom = fuse(domain(t)) |
| 22 | + |
| 23 | + (U isa AbstractTensorMap && |
| 24 | + scalartype(U) == scalartype(t) && |
| 25 | + space(U) == (codomain(t) ← V_cod)) || |
| 26 | + throw(ArgumentError("`svd_full!` requires unitary tensor U with same `scalartype`")) |
| 27 | + (S isa AbstractTensorMap && |
| 28 | + scalartype(S) == real(scalartype(t)) && |
| 29 | + space(S) == (V_cod ← V_dom)) || |
| 30 | + throw(ArgumentError("`svd_full!` requires rectangular tensor S with real `scalartype`")) |
| 31 | + (Vᴴ isa AbstractTensorMap && |
| 32 | + scalartype(Vᴴ) == scalartype(t) && |
| 33 | + space(Vᴴ) == (V_dom ← domain(t))) || |
| 34 | + throw(ArgumentError("`svd_full!` requires unitary tensor Vᴴ with same `scalartype`")) |
| 35 | + |
| 36 | + return nothing |
| 37 | +end |
| 38 | + |
| 39 | +function MatrixAlgebraKit.check_input(::typeof(svd_compact!), t::AbstractTensorMap, |
| 40 | + (U, S, Vᴴ)) |
| 41 | + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) |
| 42 | + |
| 43 | + (U isa AbstractTensorMap && |
| 44 | + scalartype(U) == scalartype(t) && |
| 45 | + space(U) == (codomain(t) ← V_cod)) || |
| 46 | + throw(ArgumentError("`svd_compact!` requires isometric tensor U with same `scalartype`")) |
| 47 | + (S isa DiagonalTensorMap && |
| 48 | + scalartype(S) == real(scalartype(t)) && |
| 49 | + space(S) == (V_cod ← V_dom)) || |
| 50 | + throw(ArgumentError("`svd_compact!` requires diagonal tensor S with real `scalartype`")) |
| 51 | + (Vᴴ isa AbstractTensorMap && |
| 52 | + scalartype(Vᴴ) == scalartype(t) && |
| 53 | + space(Vᴴ) == (V_dom ← domain(t))) || |
| 54 | + throw(ArgumentError("`svd_compact!` requires isometric tensor Vᴴ with same `scalartype`")) |
| 55 | + |
| 56 | + return nothing |
| 57 | +end |
| 58 | + |
| 59 | +# TODO: svd_vals |
| 60 | + |
| 61 | +function MatrixAlgebraKit.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, |
| 62 | + ::MatrixAlgebraKit.AbstractAlgorithm) |
| 63 | + V_cod = fuse(codomain(t)) |
| 64 | + V_dom = fuse(domain(t)) |
| 65 | + U = similar(t, domain(t) ← V_cod) |
| 66 | + S = similar(t, real(scalartype(t)), V_cod ← V_dom) |
| 67 | + Vᴴ = similar(t, domain(t) ← V_dom) |
| 68 | + return U, S, Vᴴ |
| 69 | +end |
| 70 | + |
| 71 | +function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, |
| 72 | + ::MatrixAlgebraKit.AbstractAlgorithm) |
| 73 | + V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t))) |
| 74 | + U = similar(t, domain(t) ← V_cod) |
| 75 | + S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod ← V_dom) |
| 76 | + Vᴴ = similar(t, domain(t) ← V_dom) |
| 77 | + return U, S, Vᴴ |
| 78 | +end |
| 79 | + |
| 80 | +# TODO: svd_vals |
| 81 | + |
| 82 | +function MatrixAlgebraKit.svd_full!(t::AbstractTensorMap, (U, S, Vᴴ), |
| 83 | + alg::BlockAlgorithm) |
| 84 | + MatrixAlgebraKit.check_input(svd_full!, t, (U, S, Vᴴ)) |
| 85 | + |
| 86 | + foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ) |
| 87 | + if isempty(b) # TODO: remove once MatrixAlgebraKit supports empty matrices |
| 88 | + one!(length(u) > 0 ? u : vᴴ) |
| 89 | + zerovector!(s) |
| 90 | + else |
| 91 | + u′, s′, vᴴ′ = MatrixAlgebraKit.svd_full!(b, (u, s, vᴴ), alg.alg) |
| 92 | + # deal with the case where the output is not the same as the input |
| 93 | + u === u′ || copyto!(u, u′) |
| 94 | + s === s′ || copyto!(s, s′) |
| 95 | + vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′) |
| 96 | + end |
| 97 | + return nothing |
| 98 | + end |
| 99 | + |
| 100 | + return U, S, Vᴴ |
| 101 | +end |
| 102 | + |
| 103 | +function MatrixAlgebraKit.svd_compact!(t::AbstractTensorMap, (U, S, Vᴴ), |
| 104 | + alg::BlockAlgorithm) |
| 105 | + MatrixAlgebraKit.check_input(svd_compact!, t, (U, S, Vᴴ)) |
| 106 | + |
| 107 | + foreachblock(t, U, S, Vᴴ; alg.scheduler) do _, (b, u, s, vᴴ) |
| 108 | + u′, s′, vᴴ′ = svd_compact!(b, (u, s, vᴴ), alg.alg) |
| 109 | + # deal with the case where the output is not the same as the input |
| 110 | + u === u′ || copyto!(u, u′) |
| 111 | + s === s′ || copyto!(s, s′) |
| 112 | + vᴴ === vᴴ′ || copyto!(vᴴ, vᴴ′) |
| 113 | + return nothing |
| 114 | + end |
| 115 | + |
| 116 | + return U, S, Vᴴ |
| 117 | +end |
| 118 | + |
| 119 | +function MatrixAlgebraKit.default_svd_algorithm(t::AbstractTensorMap{<:BlasFloat}; |
| 120 | + scheduler=default_blockscheduler(t), |
| 121 | + kwargs...) |
| 122 | + return BlockAlgorithm(LAPACK_DivideAndConquer(; kwargs...), scheduler) |
3 | 123 | end
|
4 | 124 |
|
5 |
| -function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap) |
6 |
| - T = complex(scalartype(t)) |
7 |
| - return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T))))) |
| 125 | +# Eigenvalue decomposition |
| 126 | +# ------------------------ |
| 127 | +function MatrixAlgebraKit.check_input(::typeof(eigh_full!), t::AbstractTensorMap, (D, V)) |
| 128 | + domain(t) == codomain(t) || |
| 129 | + throw(ArgumentError("Eigenvalue decomposition requires square input tensor")) |
| 130 | + |
| 131 | + V_D = fuse(domain(t)) |
| 132 | + |
| 133 | + (D isa DiagonalTensorMap && |
| 134 | + scalartype(D) == real(scalartype(t)) && |
| 135 | + V_D == space(D, 1)) || |
| 136 | + throw(ArgumentError("`eigh_full!` requires diagonal tensor D with isomorphic domain and real `scalartype`")) |
| 137 | + |
| 138 | + V isa AbstractTensorMap && |
| 139 | + scalartype(V) == scalartype(t) && |
| 140 | + space(V) == (codomain(t) ← V_D) || |
| 141 | + throw(ArgumentError("`eigh_full!` requires square tensor V with isomorphic domain and equal `scalartype`")) |
| 142 | + |
| 143 | + return nothing |
8 | 144 | end
|
9 | 145 |
|
10 |
| -function MAK.check_input(::typeof(MAK.eig_full!), t::AbstractTensorMap, (D, V)) |
| 146 | +function MatrixAlgebraKit.check_input(::typeof(eig_full!), t::AbstractTensorMap, (D, V)) |
11 | 147 | domain(t) == codomain(t) ||
|
12 | 148 | throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
|
13 | 149 | Tc = complex(scalartype(t))
|
| 150 | + V_D = fuse(domain(t)) |
14 | 151 |
|
15 | 152 | (D isa DiagonalTensorMap &&
|
16 | 153 | scalartype(D) == Tc &&
|
17 |
| - fuse(domain(t)) == space(D, 1)) || |
| 154 | + V_D == space(D, 1)) || |
18 | 155 | throw(ArgumentError("`eig_full!` requires diagonal tensor D with isomorphic domain and complex `scalartype`"))
|
19 | 156 |
|
20 | 157 | V isa AbstractTensorMap &&
|
21 | 158 | scalartype(V) == Tc &&
|
22 |
| - space(V) == (codomain(t) ← codomain(D)) || |
| 159 | + space(V) == (codomain(t) ← V_D) || |
23 | 160 | throw(ArgumentError("`eig_full!` requires square tensor V with isomorphic domain and complex `scalartype`"))
|
24 | 161 |
|
25 | 162 | return nothing
|
26 | 163 | end
|
27 | 164 |
|
28 |
| -function MAK.initialize_output(::typeof(MAK.eig_full!), t::AbstractTensorMap, |
29 |
| - ::MAK.LAPACK_EigAlgorithm) |
| 165 | +function MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, |
| 166 | + ::MatrixAlgebraKit.AbstractAlgorithm) |
| 167 | + V_D = fuse(domain(t)) |
| 168 | + T = real(scalartype(t)) |
| 169 | + D = DiagonalTensorMap{T}(undef, V_D) |
| 170 | + V = similar(t, codomain(t) ← V_D) |
| 171 | + return D, V |
| 172 | +end |
| 173 | + |
| 174 | +function MatrixAlgebraKit.initialize_output(::typeof(eig_full!), t::AbstractTensorMap, |
| 175 | + ::MatrixAlgebraKit.AbstractAlgorithm) |
| 176 | + V_D = fuse(domain(t)) |
30 | 177 | Tc = complex(scalartype(t))
|
31 |
| - V_diag = fuse(domain(t)) |
32 |
| - return DiagonalTensorMap{Tc}(undef, V_diag), similar(t, Tc, domain(t) ← V_diag) |
| 178 | + D = DiagonalTensorMap{Tc}(undef, V_D) |
| 179 | + V = similar(t, Tc, codomain(t) ← V_D) |
| 180 | + return D, V |
33 | 181 | end
|
34 | 182 |
|
35 |
| -function MAK.eig_full!(t::AbstractTensorMap, (D, V), alg::MAK.LAPACK_EigAlgorithm) |
36 |
| - MAK.check_input(MAK.eig_full!, t, (D, V)) |
37 |
| - foreachblock(t, D, V) do (_, (b, d, v)) |
38 |
| - d′, v′ = MAK.eig_full!(b, (d, v), alg) |
39 |
| - # deal with the case where the output is not the same as the input |
40 |
| - d === d′ || copyto!(d, d′) |
41 |
| - v === v′ || copyto!(v, v′) |
42 |
| - return nothing |
| 183 | +for f in (:eigh_full!, :eig_full!) |
| 184 | + @eval function MatrixAlgebraKit.$f(t::AbstractTensorMap, (D, V), |
| 185 | + alg::BlockAlgorithm) |
| 186 | + MatrixAlgebraKit.check_input($f, t, (D, V)) |
| 187 | + |
| 188 | + foreachblock(t, D, V; alg.scheduler) do _, (b, d, v) |
| 189 | + d′, v′ = $f(b, (d, v), alg.alg) |
| 190 | + # deal with the case where the output is not the same as the input |
| 191 | + d === d′ || copyto!(d, d′) |
| 192 | + v === v′ || copyto!(v, v′) |
| 193 | + return nothing |
| 194 | + end |
| 195 | + |
| 196 | + return D, V |
43 | 197 | end
|
44 |
| - return D, V |
45 | 198 | end
|
46 | 199 |
|
47 |
| -function MAK.default_eig_algorithm(::TensorMap{<:LinearAlgebra.BlasFloat}; kwargs...) |
48 |
| - return MAK.LAPACK_Expert(; kwargs...) |
| 200 | +function MatrixAlgebraKit.default_eig_algorithm(t::AbstractTensorMap{<:BlasFloat}; |
| 201 | + scheduler=default_blockscheduler(t), |
| 202 | + kwargs...) |
| 203 | + return BlockAlgorithm(LAPACK_Expert(; kwargs...), scheduler) |
| 204 | +end |
| 205 | +function MatrixAlgebraKit.default_eigh_algorithm(t::AbstractTensorMap{<:BlasFloat}; |
| 206 | + scheduler=default_blockscheduler(t), |
| 207 | + kwargs...) |
| 208 | + return BlockAlgorithm(LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...), |
| 209 | + scheduler) |
49 | 210 | end
|
0 commit comments