Multi-threading attempt III #203

Draft · wants to merge 18 commits into base: master
4 changes: 4 additions & 0 deletions Project.toml
@@ -6,8 +6,10 @@ version = "0.14.4"
[deps]
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
@@ -31,8 +33,10 @@ Combinatorics = "1"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
OhMyThreads = "0.7.0"
PackageExtensionCompat = "1"
Random = "1"
ScopedValues = "1.3.0"
SparseArrays = "1"
Strided = "2"
TensorKitSectors = "0.1"
1 change: 0 additions & 1 deletion docs/src/lib/tensors.md
@@ -200,7 +200,6 @@ TensorKit.add_transpose!
```@docs
compose(::AbstractTensorMap, ::AbstractTensorMap)
trace_permute!
-contract!
⊗(::AbstractTensorMap, ::AbstractTensorMap)
```

5 changes: 4 additions & 1 deletion src/TensorKit.jl
@@ -82,7 +82,7 @@ export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQ

# tensor operations
export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor
-export scalar, add!, contract!
+export scalar, add!

# truncation schemes
export notrunc, truncerr, truncdim, truncspace, truncbelow
@@ -101,6 +101,8 @@ using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend
const TO = TensorOperations

using LRUCache
using OhMyThreads
using ScopedValues

using TensorKitSectors
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗
@@ -184,6 +186,7 @@ include("spaces/vectorspaces.jl")
#-------------------------------------
# general definitions
include("tensors/abstracttensor.jl")
include("tensors/backends.jl")
include("tensors/blockiterator.jl")
include("tensors/tensor.jl")
include("tensors/adjoint.jl")
3 changes: 2 additions & 1 deletion src/planar/planaroperations.jl
@@ -142,7 +142,8 @@ function planarcontract!(C::AbstractTensorMap,
α::Number, β::Number,
backend, allocator)
if BraidingStyle(sectortype(C)) == Bosonic()
-        return contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
+        return TO.tensorcontract!(C, A, pA, false, B, pB, false, pAB,
+                                  α, β, backend, allocator)
end

codA, domA = codomainind(A), domainind(A)
84 changes: 84 additions & 0 deletions src/tensors/backends.jl
@@ -0,0 +1,84 @@
# Scheduler implementation
# ------------------------
"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

The default scheduler used when looping over different blocks in the matrix representation of a
tensor.

For controlling this value, see also [`set_blockscheduler`](@ref) and [`with_blockscheduler`](@ref).
"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

"""
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())

The default scheduler used when looping over different subblocks in a tensor.

For controlling this value, see also [`set_subblockscheduler`](@ref) and [`with_subblockscheduler`](@ref).
"""
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())

function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs)
Threads.nthreads() == 1 ? SerialScheduler() : DynamicScheduler()

else
OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)

end
end

"""
with_blockscheduler(f, [scheduler]; kwargs...)

Run `f` in a scope where the `blockscheduler` is determined by `scheduler` and `kwargs...`.

See also [`with_subblockscheduler`](@ref).
"""
@inline function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();

kwargs...)
@with blockscheduler => select_scheduler(scheduler; kwargs...) f()

end

"""
with_subblockscheduler(f, [scheduler]; kwargs...)

Run `f` in a scope where the [`subblockscheduler`](@ref) is determined by `scheduler` and `kwargs...`.
The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with an
optional set of keyword arguments. For a detailed description, consult the
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).

See also [`with_blockscheduler`](@ref).
"""
@inline function with_subblockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();

kwargs...)
@with subblockscheduler => select_scheduler(scheduler; kwargs...) f()

end

# Backend implementation
# ----------------------
# TODO: figure out a name
# TODO: what should be the default scheduler?
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
arraybackend::B = TO.DefaultBackend()
blockscheduler::BS = blockscheduler[]
subblockscheduler::SBS = subblockscheduler[]
end

function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,

A::AbstractTensorMap)
return TensorKitBackend()

end
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,

A::AbstractTensorMap)
return TensorKitBackend()

end
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,

A::AbstractTensorMap, B::AbstractTensorMap)
return TensorKitBackend()

end

function add_transform! end
function TO.select_backend(::typeof(add_transform!), C::AbstractTensorMap,
A::AbstractTensorMap)
return TensorKitBackend()
end
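The scoped schedulers and `TensorKitBackend` above can be combined roughly as follows. This is a minimal usage sketch, not part of the PR: the tensor `t` and the chosen operations are illustrative, and the new names are qualified with `TensorKit.` since the diff does not show them being exported.

```julia
using TensorKit, OhMyThreads

t = randn(ComplexF64, ℂ^32 ⊗ ℂ^32 ← ℂ^32)  # illustrative tensor

# Loop over blocks in parallel for everything in this scope:
TensorKit.with_blockscheduler(DynamicScheduler()) do
    permute(t, ((2, 1), (3,)))
end

# Symbols and keyword arguments are forwarded to OhMyThreads schedulers:
TensorKit.with_subblockscheduler(:static) do
    permute(t, ((2, 1), (3,)))
end

# Alternatively, fix both schedulers through an explicit backend object:
backend = TensorKit.TensorKitBackend(; blockscheduler=SerialScheduler(),
                                     subblockscheduler=DynamicScheduler())
```

Note that `select_scheduler` falls back to `SerialScheduler` when Julia runs with a single thread and `DynamicScheduler` otherwise, so single-threaded sessions are unaffected by default.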
2 changes: 2 additions & 0 deletions src/tensors/blockiterator.jl
@@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
Base.length(iter::BlockIterator) = length(iter.structure)
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)

Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)
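For context, the added `haskey` simply forwards to the iterator's underlying `structure`, so one can test for a block sector without materializing any block. A small sketch (the `U₁`-graded spaces and tensor are illustrative):

```julia
using TensorKit

t = randn(Float64, Rep[U₁](0 => 2, 1 => 3) ← Rep[U₁](0 => 2, 1 => 1))
iter = blocks(t)       # BlockIterator over the matrix blocks of t
haskey(iter, U₁(1))    # true: forwards to haskey(iter.structure, U₁(1))
haskey(iter, U₁(5))    # false: charge 5 has no block
block(t, U₁(1))        # the corresponding 3×1 matrix block
```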
50 changes: 44 additions & 6 deletions src/tensors/braidingtensor.jl
@@ -81,7 +81,7 @@
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
n1 = d[1] * d[2]
n2 = d[3] * d[4]
-    data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
+    data = sreshape(StridedView(blocktype(b)(undef, n1, n2)), d)
fill!(data, zero(eltype(b)))
if f₁.uncoupled == reverse(f₂.uncoupled)
braiddict = artin_braid(f₂, 1; inv=b.adjoint)
@@ -104,13 +104,27 @@
TensorMap(b::BraidingTensor) = copy!(similar(b), b)
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)

# Blocks iterator
# ---------------
blocks(b::BraidingTensor) = BlockIterator(b, blocksectors(b))
blocktype(::Type{TT}) where {TT<:BraidingTensor} = Matrix{eltype(TT)}

# TODO: efficient iterator
function Base.iterate(iter::BlockIterator{<:BraidingTensor}, state...)
next = iterate(iter.structure, state...)
isnothing(next) && return next
c, state = next
return c => block(iter.t, c), state
end
@inline Base.getindex(iter::BlockIterator{<:BraidingTensor}, c::Sector) = block(iter.t, c)


function block(b::BraidingTensor, s::Sector)
sectortype(b) == typeof(s) || throw(SectorMismatch())

# TODO: probably always square?
m = blockdim(codomain(b), s)
n = blockdim(domain(b), s)
-    data = Matrix{eltype(b)}(undef, (m, n))
+    data = blocktype(b)(undef, (m, n))

length(data) == 0 && return data # s ∉ blocksectors(b)

@@ -149,6 +163,30 @@
return data
end

# Linear Algebra
# --------------
function LinearAlgebra.mul!(C::AbstractTensorMap, A::AbstractTensorMap, B::BraidingTensor,

α::Number, β::Number)
compose(space(A), space(B)) == space(C) ||
throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))"))
levels = B.adjoint ? (1, 2, 3, 4) : (1, 2, 4, 3)
return add_braid!(C, A, ((1, 2), (4, 3)), levels, α, β)

end
function LinearAlgebra.mul!(C::AbstractTensorMap, A::BraidingTensor, B::AbstractTensorMap,

α::Number, β::Number)
compose(space(A), space(B)) == space(C) ||
throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))"))
levels = A.adjoint ? (2, 1, 3, 4) : (1, 2, 3, 4)
return add_transpose!(C, B, ((2, 1), (3, 4)), levels, α, β)

end
# TODO: implement this?
function LinearAlgebra.mul!(C::AbstractTensorMap, A::BraidingTensor, B::BraidingTensor,

α::Number, β::Number)
compose(space(A), space(B)) == space(C) ||
throw(SpaceMismatch(lazy"$(space(C)) ≠ $(space(A)) * $(space(B))"))
return mul!(C, TensorMap(A), B, α, β)

end

# Index manipulations
# -------------------
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
@@ -158,9 +196,9 @@
fusiontreetransform,
α::Number,
β::Number,
-                        backend::AbstractBackend...)
+                        backend::TensorKitBackend, allocator)
     return add_transform!(tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
-                          backend...)
+                          backend, allocator)
end

# VectorInterface
@@ -173,8 +211,8 @@

function TO.tensoradd!(C::AbstractTensorMap,
A::BraidingTensor, pA::Index2Tuple, conjA::Symbol,
-                        α::Number, β::Number, backend=TO.DefaultBackend(),
-                        allocator=TO.DefaultAllocator())
+                        α::Number, β::Number, backend::AbstractBackend,
+                        allocator)
return TO.tensoradd!(C, TensorMap(A), pA, conjA, α, β, backend, allocator)
end
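The new `mul!` methods compose a generic tensor map with a braiding tensor without materializing the braiding tensor densely. A sketch of the intended use, assuming TensorKit's convention that `BraidingTensor(V1, V2)` represents τ : V1 ⊗ V2 → V2 ⊗ V1; the spaces and `A` are illustrative:

```julia
using TensorKit, LinearAlgebra

V1, V2, W = ℂ^2, ℂ^3, ℂ^4
τ = BraidingTensor(V1, V2)                  # lazy braiding tensor
A = randn(ComplexF64, W ⊗ W ← codomain(τ))  # domain(A) must equal codomain(τ)
C = similar(A, codomain(A) ← domain(τ))
mul!(C, A, τ, true, false)                  # dispatches to the add_braid! path above
C ≈ A * TensorMap(τ)                        # agrees with materializing τ first
```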
