Skip to content

Commit a264735

Browse files
lkdvosJutho
andauthored
Add dedicated BlockIterators (#206)
* Add block iterator for TensorMap * Add block iterator for DiagonalTensorMap * Add block iterator for AdjointTensorMap * Add documentation `blocktype` * Add test * Add check diagonal * some more fixes and tests * sort blocksectors of ProductSpace * fix blocktype of Diagonal --------- Co-authored-by: Jutho <[email protected]>
1 parent 660bdf7 commit a264735

11 files changed

+130
-29
lines changed

docs/src/lib/tensors.md

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ spacetype(::Type{<:AbstractTensorMap{<:Any,S}}) where {S}
8282
sectortype(::Type{TT}) where {TT<:AbstractTensorMap}
8383
field(::Type{TT}) where {TT<:AbstractTensorMap}
8484
storagetype
85+
blocktype
8586
```
8687

8788
To obtain information about the indices, you can use:

src/TensorKit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ include("spaces/vectorspaces.jl")
184184
#-------------------------------------
185185
# general definitions
186186
include("tensors/abstracttensor.jl")
187-
# include("tensors/tensortreeiterator.jl")
187+
include("tensors/blockiterator.jl")
188188
include("tensors/tensor.jl")
189189
include("tensors/adjoint.jl")
190190
include("tensors/linalg.jl")

src/spaces/homspace.jl

+21-5
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,11 @@ function blocksectors(W::HomSpace)
8888
N₁ = length(codom)
8989
N₂ = length(dom)
9090
I = sectortype(W)
91-
if N₁ == 0 || N₂ == 0
92-
return (one(I),)
93-
elseif N₂ <= N₁
94-
return sort!(filter!(c -> hasblock(codom, c), collect(blocksectors(dom))))
91+
# TODO: is sort! still necessary now that blocksectors of ProductSpace is sorted?
92+
if N₂ <= N₁
93+
return sort!(filter!(c -> hasblock(codom, c), blocksectors(dom)))
9594
else
96-
return sort!(filter!(c -> hasblock(dom, c), collect(blocksectors(codom))))
95+
return sort!(filter!(c -> hasblock(dom, c), blocksectors(codom)))
9796
end
9897
end
9998

@@ -349,3 +348,20 @@ function fusionblockstructure(W::HomSpace, ::GlobalLRUCache)
349348
end
350349
return structure
351350
end
351+
352+
# Diagonal ranges
353+
#----------------
354+
# TODO: is this something we want to cache?
355+
function diagonalblockstructure(W::HomSpace)
356+
((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) ||
357+
throw(SpaceMismatch("Diagonal only support on V←V with a single space V"))
358+
structure = SectorDict{sectortype(W),UnitRange{Int}}() # range
359+
offset = 0
360+
dom = domain(W)[1]
361+
for c in blocksectors(W)
362+
d = dim(dom, c)
363+
structure[c] = offset .+ (1:d)
364+
offset += d
365+
end
366+
return structure
367+
end

src/spaces/productspace.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function blocksectors(P::ProductSpace{S,N}) where {S,N}
162162
end
163163
end
164164
end
165-
return bs
165+
return sort!(bs)
166166
end
167167

168168
"""

src/tensors/abstracttensor.jl

+10
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ sectortype(t::AbstractTensorMap) = sectortype(typeof(t))
195195
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
196196
field(t::AbstractTensorMap) = field(typeof(t))
197197
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
198+
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
198199
similarstoragetype(t::AbstractTensorMap, T=scalartype(t)) = similarstoragetype(typeof(t), T)
199200

200201
numout(t::AbstractTensorMap) = numout(typeof(t))
@@ -310,6 +311,15 @@ Return the matrix block of a tensor corresponding to a coupled sector `c`.
310311
See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasblock`](@ref).
311312
""" block
312313

314+
@doc """
315+
blocktype(t)
316+
317+
Return the type of the matrix blocks of a tensor.
318+
""" blocktype
319+
function blocktype(::Type{T}) where {T<:AbstractTensorMap}
320+
return Core.Compiler.return_type(block, Tuple{T,sectortype(T)})
321+
end
322+
313323
# Derived indexing behavior for tensors with trivial symmetry
314324
#-------------------------------------------------------------
315325
using TensorKit.Strided: SliceIndex

src/tensors/adjoint.jl

+15-5
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,21 @@ storagetype(::Type{AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT}
2525
#----------------------
2626
block(t::AdjointTensorMap, s::Sector) = block(parent(t), s)'
2727

28-
function blocks(t::AdjointTensorMap)
29-
iter = Base.Iterators.map(blocks(parent(t))) do (c, b)
30-
return c => b'
31-
end
32-
return iter
28+
blocks(t::AdjointTensorMap) = BlockIterator(t, blocks(parent(t)))
29+
30+
function blocktype(::Type{AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT}
31+
return Base.promote_op(adjoint, blocktype(TT))
32+
end
33+
34+
function Base.iterate(iter::BlockIterator{<:AdjointTensorMap}, state...)
35+
next = iterate(iter.structure, state...)
36+
isnothing(next) && return next
37+
(c, b), newstate = next
38+
return c => adjoint(b), newstate
39+
end
40+
41+
function Base.getindex(iter::BlockIterator{<:AdjointTensorMap}, c::Sector)
42+
return adjoint(Base.getindex(iter.structure, c))
3343
end
3444

3545
function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂},

src/tensors/blockiterator.jl

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
struct BlockIterator{T<:AbstractTensorMap,S}
3+
4+
Iterator over the blocks of type `T`, possibly holding some pre-computed data of type `S`
5+
"""
6+
struct BlockIterator{T<:AbstractTensorMap,S}
7+
t::T
8+
structure::S
9+
end
10+
11+
Base.IteratorSize(::BlockIterator) = Base.HasLength()
12+
Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
13+
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
14+
Base.length(iter::BlockIterator) = length(iter.structure)
15+
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)

src/tensors/diagonal.jl

+17-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,23 @@ function block(d::DiagonalTensorMap, s::Sector)
110110
return Diagonal(view(d.data, 1:0))
111111
end
112112

113-
# TODO: is relying on generic AbstractTensorMap blocks sufficient?
113+
blocks(t::DiagonalTensorMap) = BlockIterator(t, diagonalblockstructure(space(t)))
114+
function blocktype(::Type{DiagonalTensorMap{T,S,A}}) where {T,S,A}
115+
return Diagonal{T,SubArray{T,1,A,Tuple{UnitRange{Int}},true}}
116+
end
117+
118+
function Base.iterate(iter::BlockIterator{<:DiagonalTensorMap}, state...)
119+
next = iterate(iter.structure, state...)
120+
isnothing(next) && return next
121+
(c, r), newstate = next
122+
return c => Diagonal(view(iter.t.data, r)), newstate
123+
end
124+
125+
function Base.getindex(iter::BlockIterator{<:DiagonalTensorMap}, c::Sector)
126+
sectortype(iter.t) === typeof(c) || throw(SectorMismatch())
127+
r = get(iter.structure, c, 1:0)
128+
return Diagonal(view(iter.t.data, r))
129+
end
114130

115131
# Indexing and getting and setting the data at the subblock level
116132
#-----------------------------------------------------------------

src/tensors/tensor.jl

+23-16
Original file line numberDiff line numberDiff line change
@@ -421,28 +421,35 @@ end
421421

422422
# Getting and setting the data at the block level
423423
#-------------------------------------------------
424-
function block(t::TensorMap, s::Sector)
425-
sectortype(t) == typeof(s) || throw(SectorMismatch())
426-
structure = fusionblockstructure(t).blockstructure
427-
(d₁, d₂), r = get(structure, s) do
424+
block(t::TensorMap, c::Sector) = blocks(t)[c]
425+
426+
blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)
427+
428+
function blocktype(::Type{TT}) where {TT<:TensorMap}
429+
A = storagetype(TT)
430+
T = eltype(A)
431+
return Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}
432+
end
433+
434+
function Base.iterate(iter::BlockIterator{<:TensorMap}, state...)
435+
next = iterate(iter.structure, state...)
436+
isnothing(next) && return next
437+
(c, (sz, r)), newstate = next
438+
return c => reshape(view(iter.t.data, r), sz), newstate
439+
end
440+
441+
function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector)
442+
sectortype(iter.t) === typeof(c) || throw(SectorMismatch())
443+
(d₁, d₂), r = get(iter.structure, c) do
428444
# is s is not a key, at least one of the two dimensions will be zero:
429445
# it then does not matter where exactly we construct a view in `t.data`,
430446
# as it will have length zero anyway
431-
d₁′ = blockdim(codomain(t), s)
432-
d₂′ = blockdim(domain(t), s)
447+
d₁′ = blockdim(codomain(iter.t), c)
448+
d₂′ = blockdim(domain(iter.t), c)
433449
l = d₁′ * d₂′
434450
return (d₁′, d₂′), 1:l
435451
end
436-
return reshape(view(t.data, r), (d₁, d₂))
437-
end
438-
439-
function blocks(t::TensorMap)
440-
structure = fusionblockstructure(t).blockstructure
441-
iter = Base.Iterators.map(structure) do (c, ((d₁, d₂), r))
442-
b = reshape(view(t.data, r), (d₁, d₂))
443-
return c => b
444-
end
445-
return iter
452+
return reshape(view(iter.t.data, r), (d₁, d₂))
446453
end
447454

448455
# Indexing and getting and setting the data at the subblock level

test/diagonal.jl

+9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
1414
@test space(t) == (V V)
1515
@test space(t') == (V V)
1616
@test dim(t) == dim(space(t))
17+
# blocks
18+
bs = @constinferred blocks(t)
19+
(c, b1), state = @constinferred Nothing iterate(bs)
20+
@test c == first(blocksectors(V V))
21+
next = @constinferred Nothing iterate(bs, state)
22+
b2 = @constinferred block(t, first(blocksectors(t)))
23+
@test b1 == b2
24+
@test eltype(bs) === typeof(b1) === TensorKit.blocktype(t)
25+
# basic linear algebra
1726
@test isa(@constinferred(norm(t)), real(T))
1827
@test norm(t)^2 dot(t, t)
1928
α = rand(T)

test/tensors.jl

+17
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ for V in spacelist
9191
@test space(t) == (W one(W))
9292
@test domain(t) == one(W)
9393
@test typeof(t) == TensorMap{T,spacetype(t),5,0,Vector{T}}
94+
# blocks
95+
bs = @constinferred blocks(t)
96+
(c, b1), state = @constinferred Nothing iterate(bs)
97+
@test c == first(blocksectors(W))
98+
next = @constinferred Nothing iterate(bs, state)
99+
b2 = @constinferred block(t, first(blocksectors(t)))
100+
@test b1 == b2
101+
@test eltype(bs) === typeof(b1) === TensorKit.blocktype(t)
94102
end
95103
end
96104
@timedtestset "Tensor Dict conversion" begin
@@ -143,6 +151,15 @@ for V in spacelist
143151
@test dim(t) == dim(space(t))
144152
@test codomain(t) == codomain(W)
145153
@test domain(t) == domain(W)
154+
# blocks for adjoint
155+
bs = @constinferred blocks(t')
156+
(c, b1), state = @constinferred Nothing iterate(bs)
157+
@test c == first(blocksectors(W'))
158+
next = @constinferred Nothing iterate(bs, state)
159+
b2 = @constinferred block(t', first(blocksectors(t')))
160+
@test b1 == b2
161+
@test eltype(bs) === typeof(b1) === TensorKit.blocktype(t')
162+
# linear algebra
146163
@test isa(@constinferred(norm(t)), real(T))
147164
@test norm(t)^2 dot(t, t)
148165
α = rand(T)

0 commit comments

Comments
 (0)