Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multithreaded implementations using OhMyThreads #117

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6' # LTS version
- '1.9' # minimal version
- '1' # automatically expands to the latest stable 1.x release of Julia
os:
- ubuntu-latest
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ version = "0.12.3"
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
@@ -27,8 +28,9 @@ ChainRulesTestUtils = "1"
Combinatorics = "1"
FiniteDifferences = "0.12"
HalfIntegers = "1"
LinearAlgebra = "1"
LRUCache = "1.0.2"
LinearAlgebra = "1"
OhMyThreads = "0.5"
PackageExtensionCompat = "1"
Random = "1"
Strided = "2"
@@ -38,7 +40,7 @@ TestExtras = "0.2"
TupleTools = "1.1"
VectorInterface = "0.4"
WignerSymbols = "1,2"
julia = "1.6"
julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4 changes: 4 additions & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
@@ -91,6 +91,10 @@ using TupleTools
using TupleTools: StaticLength

using Strided
using OhMyThreads
"""
    default_scheduler(t)
    default_scheduler(::Type)

Return the scheduler that is used whenever a `scheduler` keyword is not supplied.

Multithreading is switched off by default: every type maps to `SerialScheduler()`, and a
call with a non-type argument is forwarded to the method for `typeof(t)`. Adding a more
specific `default_scheduler(::Type{...})` method changes the default for that type.
"""
function default_scheduler(::Type)
    return SerialScheduler()
end
function default_scheduler(t)
    return default_scheduler(typeof(t))
end

using VectorInterface

71 changes: 45 additions & 26 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
@@ -403,23 +403,26 @@ end
function tsvd!(t::AdjointTensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg)
alg::Union{SVD,SDD}=SDD(),
scheduler::Scheduler=default_scheduler(t))
u, s, vt, err = tsvd!(adjoint(t); trunc, p, alg, scheduler)
return adjoint(vt), adjoint(s), adjoint(u), err
end

function tsvd!(t::TensorMap;
trunc::TruncationScheme=NoTruncation(),
p::Real=2,
alg::Union{SVD,SDD}=SDD())
#early return
trunc::TruncationScheme=NoTruncation(), p::Real=2,
alg::Union{SVD,SDD}=SDD(),
scheduler::Scheduler=default_scheduler(t))
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!)

# early return
if isempty(blocksectors(t))
truncerr = zero(real(scalartype(t)))
return _empty_svdtensors(t)..., truncerr
end

S = spacetype(t)
Udata, Σdata, Vdata, dims = _compute_svddata!(t, alg)
Udata, Σdata, Vdata, dims = _compute_svddata!(t, alg; scheduler)
if !isa(trunc, NoTruncation)
Σdata, truncerr = _truncate!(Σdata, trunc, p)
Udata, Σdata, Vdata, dims = _implement_svdtruncation!(t, Udata, Σdata, Vdata, dims)
@@ -437,26 +440,42 @@ function tsvd!(t::TensorMap;
end

# helper functions

function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!)
I = sectortype(t)
A = storagetype(t)
Udata = SectorDict{I,A}()
Vdata = SectorDict{I,A}()
dims = SectorDict{I,Int}()
local Σdata
for (c, b) in blocks(t)
U, Σ, V = MatrixAlgebra.svd!(b, alg)
Udata[c] = U
Vdata[c] = V
if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction
Σdata[c] = Σ
else
Σdata = SectorDict(c => Σ)
end
dims[c] = length(Σ)
"""
    _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}; scheduler=default_scheduler(t))

Run `MatrixAlgebra.svd!(b, alg)` on every block `b` of `blocks(t)`, distributing the
blocks according to `scheduler`.

Returns `(Udata, Σdata, Vdata, dims)`: four `SectorDict`s that share the keys of
`blocks(t)`. `dims` stores, per key, the `length` of the corresponding entry of `Σdata`.
"""
function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD};
                           scheduler::Scheduler=default_scheduler(t))
    Tdata = blocks(t)
    Tkeys = keys(Tdata)
    Tvals = values(Tdata)

    # Preallocate one output slot per block, so each task only writes to its own index.
    Uvals = similar(Tvals)
    # The container type of the singular values is not stored anywhere; infer it from
    # what `similar` would return for a block with a real scalar type.
    Σtype = Core.Compiler.return_type(similar,
                                      Tuple{eltype(Tvals),Type{real(scalartype(t))},Int})
    Σvals = similar(Tvals, Σtype)
    Vvals = similar(Tvals)
    dimsvals = similar(Tvals, Int)

    # Iterate over indices instead of `enumerate(Tvals)`: an index range is an indexable
    # collection that any scheduler can split into chunks, and the indices are guaranteed
    # to be valid for the `similar` output vectors.
    tforeach(eachindex(Tvals); scheduler) do i
        Uvals[i], Σvals[i], Vvals[i] = MatrixAlgebra.svd!(Tvals[i], alg)
        dimsvals[i] = length(Σvals[i])
        return nothing
    end

    # TODO: do we need copies of the keys?
    Udata = SectorDict{eltype(Tkeys),eltype(Uvals)}(copy(Tkeys), Uvals)
    Σdata = SectorDict{eltype(Tkeys),eltype(Σvals)}(copy(Tkeys), Σvals)
    Vdata = SectorDict{eltype(Tkeys),eltype(Vvals)}(copy(Tkeys), Vvals)
    dims = SectorDict{eltype(Tkeys),eltype(dimsvals)}(copy(Tkeys), dimsvals)

    return Udata, Σdata, Vdata, dims
end
"""
    _compute_svddata!(t::TrivialTensorMap, alg::Union{SVD,SDD}; scheduler=default_scheduler(t))

Specialization for `TrivialTensorMap`: perform a single `MatrixAlgebra.svd!` on `t.data`
and wrap each result in a one-entry `SectorDict` keyed by `Trivial()`.

The `scheduler` keyword is accepted so the signature matches the `TensorMap` method, but
it is ignored because there is only one decomposition to compute.
"""
function _compute_svddata!(t::TrivialTensorMap, alg::Union{SVD,SDD};
                           scheduler::Scheduler=default_scheduler(t))
    # every output is a dictionary with the single key `Trivial()`
    wrap(x) = SectorDict(Trivial() => x)

    U, Σ, V = MatrixAlgebra.svd!(t.data, alg)
    return wrap(U), wrap(Σ), wrap(V), wrap(length(Σ))
end

94 changes: 42 additions & 52 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
@@ -311,7 +311,8 @@
fusiontreetransform,
α::Number,
β::Number,
backend::Backend...) where {S,N₁,N₂}
backend::Backend...;
scheduler::Scheduler=default_scheduler(tdst)) where {S,N₁,N₂}
@boundscheck begin
all(i -> space(tsrc, p₁[i]) == space(tdst, i), 1:N₁) ||
throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)),
@@ -321,74 +322,63 @@
dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)"))
end

I = sectortype(S)
# special case for trivial permutations
if p₁ == codomainind(tsrc) && p₂ == domainind(tsrc)
add!(tdst, tsrc, α, β)
elseif I === Trivial
_add_trivial_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
elseif FusionStyle(I) isa UniqueFusion
_add_abelian_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
else
_add_general_kernel!(tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β, backend...)
return add!(tdst, tsrc, α, β; scheduler)
end

_add_transform!(sectortype(S), tdst, tsrc, (p₁, p₂), fusiontreetransform, α, β,
backend...; scheduler)
return tdst
end

# internal methods: no argument types
function _add_trivial_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
# Trivial sector type: forward the data obtained from `tdst[]` and `tsrc[]` to a single
# `TO.tensoradd!` call. `fusiontreetransform` and `scheduler` are not used here.
function _add_transform!(::Type{Trivial}, tdst, tsrc, p, fusiontreetransform, α, β,
                         backend...; scheduler::Scheduler)
    dst = tdst[]
    src = tsrc[]
    TO.tensoradd!(dst, p, src, :N, α, β, backend...)
    return nothing
end

function _add_abelian_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
if Threads.nthreads() > 1
Threads.@sync for (f₁, f₂) in fusiontrees(tsrc)
Threads.@spawn _add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
f₁, f₂, α, β, backend...)
end
else
for (f₁, f₂) in fusiontrees(tsrc)
_add_abelian_block!(tdst, tsrc, p, fusiontreetransform,
f₁, f₂, α, β, backend...)
end
end
return tdst
# Non-trivial sector types: pick the implementation by dispatching on `FusionStyle(I)`
# and pass every other argument through unchanged.
function _add_transform!(::Type{I}, tdst, tsrc, p, fusiontreetransform, α, β, backend...;
                         scheduler::Scheduler) where {I<:Sector}
    style = FusionStyle(I)
    return __add_transform!(style, tdst, tsrc, p, fusiontreetransform, α, β, backend...;
                            scheduler)
end

function _add_abelian_block!(tdst, tsrc, p, fusiontreetransform, f₁, f₂, α, β, backend...)
(f₁′, f₂′), coeff = first(fusiontreetransform(f₁, f₂))
TO.tensoradd!(tdst[f₁′, f₂′], p, tsrc[f₁, f₂], :N, α * coeff, β, backend...)
# `UniqueFusion` implementation: for every pair of fusion trees of `tsrc`, only the first
# `(trees, coefficient)` entry returned by `fusiontreetransform` is used, and the source
# block is added into the destination block selected by those trees. The iteration over
# `fusiontrees(tsrc)` is driven by `tforeach` with the given `scheduler`.
function __add_transform!(::UniqueFusion, tdst, tsrc, p, fusiontreetransform, α, β,
                          backend...; scheduler::Scheduler)
    function transform_block!(trees)
        f₁, f₂ = trees
        (f₁′, f₂′), coeff = first(fusiontreetransform(f₁, f₂))
        dst = tdst[f₁′, f₂′]
        src = tsrc[f₁, f₂]
        TO.tensoradd!(dst, p, src, :N, α * coeff, β, backend...)
        return nothing
    end
    tforeach(transform_block!, fusiontrees(tsrc); scheduler)
    return nothing
end

function _add_general_kernel!(tdst, tsrc, p, fusiontreetransform, α, β, backend...)
if iszero(β)
tdst = zerovector!(tdst)
elseif β != 1
tdst = scale!(tdst, β)
end
if Threads.nthreads() > 1
Threads.@sync for s₁ in sectors(codomain(tsrc)), s₂ in sectors(domain(tsrc))
Threads.@spawn _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁,
s₂, α, β, backend...)
end
else
for (f₁, f₂) in fusiontrees(tsrc)
# TODO: find a way to merge implementations of serial and parallel versions
# TODO: invert the way fusiontreetransform is made, so we can loop over output trees instead of input trees
function __add_transform!(::FusionStyle, tdst, tsrc, p, fusiontreetransform, α, β,
backend...; scheduler::Scheduler)
scale!(tdst, β)
β′ = One()
if scheduler isa SerialScheduler
# serial version does not need to care about simultaneous writes
tforeach(fusiontrees(tsrc); scheduler) do (f₁, f₂)
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
TO.tensoradd!(tdst[f₁′, f₂′], p, tsrc[f₁, f₂], :N, α * coeff, true,
TO.tensoradd!(tdst[f₁′, f₂′], p, tsrc[f₁, f₂], :N, α * coeff, β′,
backend...)
end
end
end
return nothing
end

function _add_nonabelian_sector!(tdst, tsrc, p, fusiontreetransform, s₁, s₂, α, β,
backend...)
for (f₁, f₂) in fusiontrees(tsrc)
(f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
TO.tensoradd!(tdst[f₁′, f₂′], p, tsrc[f₁, f₂], :N, α * coeff, true, backend...)
else
tforeach(Iterators.product(sectors(codomain(tsrc)), sectors(domain(tsrc)));

Check warning on line 371 in src/tensors/indexmanipulations.jl

Codecov / codecov/patch

src/tensors/indexmanipulations.jl#L371

Added line #L371 was not covered by tests
scheduler) do (s₁, s₂)
# each task will only write to a fixed set of output sectors, so no concurrent writes.
for (f₁, f₂) in fusiontrees(tsrc)
(f₁.uncoupled == s₁ && f₂.uncoupled == s₂) || continue
for ((f₁′, f₂′), coeff) in fusiontreetransform(f₁, f₂)
TO.tensoradd!(tdst[f₁′, f₂′], p, tsrc[f₁, f₂], :N, α * coeff, β′,
backend...)
end
end
return nothing
end
end
return nothing
27 changes: 15 additions & 12 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
@@ -241,24 +241,27 @@ end
# TensorMap multiplication
function LinearAlgebra.mul!(tC::AbstractTensorMap,
tA::AbstractTensorMap,
tB::AbstractTensorMap, α=true, β=false)
tB::AbstractTensorMap, α=true, β=false;
scheduler::Scheduler=default_scheduler(tC))
if !(codomain(tC) == codomain(tA) && domain(tC) == domain(tB) &&
domain(tA) == codomain(tB))
throw(SpaceMismatch("$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
end
for c in blocksectors(tC)
if hasblock(tA, c) # then also tB should have such a block
A = block(tA, c)
B = block(tB, c)
C = block(tC, c)
mul!(StridedView(C), StridedView(A), StridedView(B), α, β)
elseif β != one(β)
rmul!(block(tC, c), β)
end
end
_mul!(c) = _mul_block!(c, tC, tA, tB, α, β)
tforeach(_mul!, blocksectors(tC); scheduler)
return tC
end
# TODO: reconsider wrapping the blocks in a StridedView, consider spawning threads for different blocks
# Per-sector kernel used by `mul!`: update the block of `tC` labelled by `c`.
# If `tA` has no block `c`, no product is computed and the block of `tC` is only
# rescaled by `β` (skipped when `β` equals one). Otherwise the five-argument `mul!` is
# applied to `StridedView`s of the three blocks. Always returns `nothing`.
function _mul_block!(c, tC, tA, tB, α, β)
    if !hasblock(tA, c)
        β != one(β) && rmul!(block(tC, c), β)
        return nothing
    end

    # tA has this block, so tB is expected to have it as well
    A = block(tA, c)
    B = block(tB, c)
    C = block(tC, c)
    mul!(StridedView(C), StridedView(A), StridedView(B), α, β)
    return nothing
end

# TensorMap inverse
function Base.inv(t::AbstractTensorMap)
Loading