Skip to content

Commit a827476

Browse files
committed
Implement eig_full!
1 parent d4668c4 commit a827476

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.14.5"
66
[deps]
77
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
910
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -31,6 +32,7 @@ Combinatorics = "1"
3132
FiniteDifferences = "0.12"
3233
LRUCache = "1.0.2"
3334
LinearAlgebra = "1"
35+
MatrixAlgebraKit = "0.1.1"
3436
PackageExtensionCompat = "1"
3537
Random = "1"
3638
SparseArrays = "1"

src/TensorKit.jl

+3
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon
100100
using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend
101101
const TO = TensorOperations
102102

103+
using MatrixAlgebraKit: MatrixAlgebraKit as MAK
104+
103105
using LRUCache
104106

105107
using TensorKitSectors
@@ -194,6 +196,7 @@ include("tensors/treetransformers.jl")
194196
include("tensors/indexmanipulations.jl")
195197
include("tensors/diagonal.jl")
196198
include("tensors/truncation.jl")
199+
include("tensors/matrixalgebrakit.jl")
197200
include("tensors/factorizations.jl")
198201
include("tensors/braidingtensor.jl")
199202

src/tensors/matrixalgebrakit.jl

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
function MAK.copy_input(::typeof(MAK.eig_full), t::AbstractTensorMap)
2+
return copy_oftype(t, factorisation_scalartype(MAK.eig_full!, t))
3+
end
4+
5+
function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
6+
T = complex(scalartype(t))
7+
return promote_type(ComplexF32, typeof(zero(T) / sqrt(abs2(one(T)))))
8+
end
9+
10+
function MAK.check_input(::typeof(MAK.eig_full!), t::AbstractTensorMap, (D, V))
11+
domain(t) == codomain(t) ||
12+
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
13+
Tc = complex(scalartype(t))
14+
15+
(D isa DiagonalTensorMap &&
16+
scalartype(D) == Tc &&
17+
fuse(domain(t)) == space(D, 1)) ||
18+
throw(ArgumentError("`eig_full!` requires diagonal tensor D with isomorphic domain and complex `scalartype`"))
19+
20+
V isa AbstractTensorMap &&
21+
scalartype(V) == Tc &&
22+
space(V) == (codomain(t) codomain(D)) ||
23+
throw(ArgumentError("`eig_full!` requires square tensor V with isomorphic domain and complex `scalartype`"))
24+
25+
return nothing
26+
end
27+
28+
function MAK.initialize_output(::typeof(MAK.eig_full!), t::AbstractTensorMap,
29+
::MAK.LAPACK_EigAlgorithm)
30+
Tc = complex(scalartype(t))
31+
V_diag = fuse(domain(t))
32+
return DiagonalTensorMap{Tc}(undef, V_diag), similar(t, Tc, domain(t) V_diag)
33+
end
34+
35+
function MAK.eig_full!(t::AbstractTensorMap, (D, V), alg::MAK.LAPACK_EigAlgorithm)
36+
MAK.check_input(MAK.eig_full!, t, (D, V))
37+
foreachblock(t, D, V) do (_, (b, d, v))
38+
d′, v′ = MAK.eig_full!(b, (d, v), alg)
39+
# deal with the case where the output is not the same as the input
40+
d === d′ || copyto!(d, d′)
41+
v === v′ || copyto!(v, v′)
42+
return nothing
43+
end
44+
return D, V
45+
end
46+
47+
function MAK.default_eig_algorithm(::TensorMap{<:LinearAlgebra.BlasFloat}; kwargs...)
48+
return MAK.LAPACK_Expert(; kwargs...)
49+
end

0 commit comments

Comments
 (0)