Skip to content

Commit bb01e1a

Browse files
committed
rework default scheduler settings
1 parent cc6b4ea commit bb01e1a

6 files changed

+40
-32
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
1010
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1213
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1314
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1415
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
@@ -35,6 +36,7 @@ LinearAlgebra = "1"
3536
OhMyThreads = "0.7.0"
3637
PackageExtensionCompat = "1"
3738
Random = "1"
39+
ScopedValues = "1.3.0"
3840
SparseArrays = "1"
3941
Strided = "2"
4042
TensorKitSectors = "0.1"

src/TensorKit.jl

+2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ const TO = TensorOperations
102102

103103
using LRUCache
104104
using OhMyThreads
105+
using ScopedValues
105106

106107
using TensorKitSectors
107108
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ,
@@ -185,6 +186,7 @@ include("spaces/vectorspaces.jl")
185186
#-------------------------------------
186187
# general definitions
187188
include("tensors/abstracttensor.jl")
189+
include("tensors/backends.jl")
188190
include("tensors/blockiterator.jl")
189191
include("tensors/tensor.jl")
190192
include("tensors/adjoint.jl")

src/tensors/backends.jl

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Scheduler implementation
2+
# ------------------------
3+
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())
4+
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())
5+
6+
# Backend implementation
7+
# ----------------------
8+
# TODO: figure out a name
9+
# TODO: what should be the default scheduler?
10+
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
11+
arraybackend::B = TO.DefaultBackend()
12+
blockscheduler::BS = blockscheduler[]
13+
subblockscheduler::SBS = subblockscheduler[]
14+
end
15+
16+
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,
17+
A::AbstractTensorMap)
18+
return TensorKitBackend()
19+
end
20+
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,
21+
A::AbstractTensorMap)
22+
return TensorKitBackend()
23+
end
24+
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,
25+
A::AbstractTensorMap, B::AbstractTensorMap)
26+
return TensorKitBackend()
27+
end

src/tensors/indexmanipulations.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ function add_transform_kernel!(tdst::TensorMap,
519519
structure_src = transformer.structure_src.fusiontreestructure
520520

521521
tforeach(transformer.rows, transformer.cols, transformer.vals;
522-
backend.scheduler) do row, col, val
522+
scheduler=backend.subblockscheduler) do row, col, val
523523
sz_dst, str_dst, offset_dst = structure_dst[col]
524524
subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst)
525525

@@ -546,7 +546,7 @@ function add_transform_kernel!(tdst::TensorMap,
546546
rows = rowvals(transformer.matrix)
547547
vals = nonzeros(transformer.matrix)
548548

549-
tforeach(axes(transformer.matrix, 2); backend.scheduler) do j
549+
tforeach(axes(transformer.matrix, 2); scheduler=backend.subblockscheduler) do j
550550
sz_dst, str_dst, offset_dst = structure_dst[j]
551551
subblock_dst = StridedView(tdst.data, sz_dst, str_dst, offset_dst)
552552
nzrows = nzrange(transformer.matrix, j)
@@ -602,7 +602,7 @@ end
602602

603603
function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β,
604604
backend::TensorKitBackend, allocator)
605-
tforeach(fusiontrees(tsrc); backend.scheduler) do (f₁, f₂)
605+
tforeach(fusiontrees(tsrc); scheduler=backend.subblockscheduler) do (f₁, f₂)
606606
return _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
607607
f₁, f₂, α, β, backend.arraybackend, allocator)
608608
end
@@ -624,16 +624,17 @@ function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backen
624624
tdst = scale!(tdst, β)
625625
end
626626
β′ = One()
627-
if backend.scheduler isa SerialScheduler
627+
if backend.subblockscheduler isa SerialScheduler
628628
for (f₁, f₂) in fusiontrees(tsrc)
629629
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
630630
@inbounds TO.tensoradd!(tdst[f₁′, f₂′], tsrc[f₁, f₂], p, false, α * coeff,
631631
β′, backend.arraybackend, allocator)
632632
end
633633
end
634634
else
635-
tforeach(Iterators.product(sectors(codomain(tsrc)), sectors(domain(tsrc)))) do (s₁,
636-
s₂)
635+
tforeach(Iterators.product(sectors(codomain(tsrc)), sectors(domain(tsrc)));
636+
scheduler=backend.subblockscheduler) do (s₁,
637+
s₂)
637638
return _add_nonabelian_sector!(tdts, tsrc, p, fusiontreetransform, s₁, s₂, α,
638639
β′, backend.arraybackend, allocator)
639640
end

src/tensors/tensoroperations.jl

-24
Original file line numberDiff line numberDiff line change
@@ -144,30 +144,6 @@ TO.tensorcost(t::AbstractTensorMap, i::Int) = dim(space(t, i))
144144
# IMPLEMENTATONS
145145
#----------------
146146

147-
# Backend implementation
148-
# ----------------------
149-
# TODO: figure out a name
150-
# TODO: what should be the default scheduler?
151-
# TODO: should we allow a separate scheduler for "blocks" and "subblocks"
152-
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
153-
arraybackend::B = TO.DefaultBackend()
154-
blockscheduler::BS = SerialScheduler()
155-
subblockscheduler::SBS = SerialScheduler()
156-
end
157-
158-
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,
159-
A::AbstractTensorMap)
160-
return TensorKitBackend()
161-
end
162-
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,
163-
A::AbstractTensorMap)
164-
return TensorKitBackend()
165-
end
166-
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,
167-
A::AbstractTensorMap, B::AbstractTensorMap)
168-
return TensorKitBackend()
169-
end
170-
171147
# Trace implementation
172148
#----------------------
173149
"""

test/runtests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ sectorlist = (Z2Irrep, Z3Irrep, Z4Irrep, Z3Irrep ⊠ Z4Irrep,
5454
Z2Irrep FibonacciAnyon FibonacciAnyon)
5555

5656
Ti = time()
57-
include("fusiontrees.jl")
58-
include("spaces.jl")
57+
# include("fusiontrees.jl")
58+
# include("spaces.jl")
5959
include("tensors.jl")
6060
include("diagonal.jl")
6161
include("planar.jl")

0 commit comments

Comments
 (0)