Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-threading attempt III #203

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ version = "0.14.4"
[deps]
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
Expand All @@ -31,8 +33,10 @@ Combinatorics = "1"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
OhMyThreads = "0.7.0"
PackageExtensionCompat = "1"
Random = "1"
ScopedValues = "1.3.0"
SparseArrays = "1"
Strided = "2"
TensorKitSectors = "0.1"
Expand Down
1 change: 0 additions & 1 deletion docs/src/lib/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ TensorKit.add_transpose!
```@docs
compose(::AbstractTensorMap, ::AbstractTensorMap)
trace_permute!
contract!
⊗(::AbstractTensorMap, ::AbstractTensorMap)
```

Expand Down
5 changes: 4 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQ

# tensor operations
export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor
export scalar, add!, contract!
export scalar, add!

# truncation schemes
export notrunc, truncerr, truncdim, truncspace, truncbelow
Expand All @@ -101,6 +101,8 @@ using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend
const TO = TensorOperations

using LRUCache
using OhMyThreads
using ScopedValues

using TensorKitSectors
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗
Expand Down Expand Up @@ -184,6 +186,7 @@ include("spaces/vectorspaces.jl")
#-------------------------------------
# general definitions
include("tensors/abstracttensor.jl")
include("tensors/backends.jl")
include("tensors/blockiterator.jl")
include("tensors/tensor.jl")
include("tensors/adjoint.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/planar/planaroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ function planarcontract!(C::AbstractTensorMap,
α::Number, β::Number,
backend, allocator)
if BraidingStyle(sectortype(C)) == Bosonic()
return contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
return TO.tensorcontract!(C, A, pA, false, B, pB, false, pAB,
α, β, backend, allocator)
end

codA, domA = codomainind(A), domainind(A)
Expand Down
107 changes: 107 additions & 0 deletions src/tensors/backends.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Scheduler implementation
# ------------------------
"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

The default scheduler used when looping over different blocks in the matrix representation of a
tensor.

For controlling this value, see also [`set_blockscheduler`](@ref) and [`with_blockscheduler`](@ref).
"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

"""
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())

The default scheduler used when looping over different subblocks in a tensor.

For controlling this value, see also [`set_subblockscheduler`](@ref) and [`with_subblockscheduler`](@ref).
"""
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())

function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs)
Threads.nthreads() == 1 ? SerialScheduler() : DynamicScheduler()

Check warning on line 24 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L22-L24

Added lines #L22 - L24 were not covered by tests
else
OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)

Check warning on line 26 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L26

Added line #L26 was not covered by tests
end
end

"""
set_blockscheduler!([scheduler]; kwargs...) -> previuos

Set the default scheduler used in looping over the different blocks in the matrix representation
of a tensor.
The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with optional
set of keywords arguments. For a detailed description, consult the
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).

See also [`with_blockscheduler`](@ref).
"""
function set_blockscheduler!(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
previous = blockscheduler[]
blockscheduler[] = select_scheduler(scheduler; kwargs...)
return previous

Check warning on line 44 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L41-L44

Added lines #L41 - L44 were not covered by tests
end

"""
with_blockscheduler(f, [scheduler]; kwargs...)

Run `f` in a scope where the `blockscheduler` is determined by `scheduler` and `kwargs...`.

See also [`set_blockscheduler!`](@ref).
"""
function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
@with blockscheduler => select_scheduler(scheduler; kwargs...) f()

Check warning on line 55 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L54-L55

Added lines #L54 - L55 were not covered by tests
end

"""
set_subblockscheduler!([scheduler]; kwargs...) -> previous

Set the default scheduler used in looping over the different subblocks in a tensor.
The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with optional
set of keywords arguments. For a detailed description, consult the
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).

See also [`with_subblockscheduler`](@ref).
"""
function set_subblockscheduler!(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
previous = subblockscheduler[]
subblockscheduler[] = select_scheduler(scheduler; kwargs...)
return previous

Check warning on line 71 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L68-L71

Added lines #L68 - L71 were not covered by tests
end

"""
with_subblockscheduler(f, [scheduler]; kwargs...)

Run `f` in a scope where the [`subblockscheduler`](@ref) is determined by `scheduler` and `kwargs...`.

See also [`set_subblockscheduler!`](@ref).
"""
function with_subblockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();

Check warning on line 81 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L81

Added line #L81 was not covered by tests
kwargs...)
@with subblockscheduler => select_scheduler(scheduler; kwargs...) f()

Check warning on line 83 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L83

Added line #L83 was not covered by tests
end

# Backend implementation
# ----------------------
# TODO: figure out a name
# TODO: what should be the default scheduler?
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
arraybackend::B = TO.DefaultBackend()
blockscheduler::BS = blockscheduler[]
subblockscheduler::SBS = subblockscheduler[]
end

function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,

Check warning on line 96 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L96

Added line #L96 was not covered by tests
A::AbstractTensorMap)
return TensorKitBackend()

Check warning on line 98 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L98

Added line #L98 was not covered by tests
end
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,

Check warning on line 100 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L100

Added line #L100 was not covered by tests
A::AbstractTensorMap)
return TensorKitBackend()

Check warning on line 102 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L102

Added line #L102 was not covered by tests
end
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,

Check warning on line 104 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L104

Added line #L104 was not covered by tests
A::AbstractTensorMap, B::AbstractTensorMap)
return TensorKitBackend()

Check warning on line 106 in src/tensors/backends.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/backends.jl#L106

Added line #L106 was not covered by tests
end
2 changes: 2 additions & 0 deletions src/tensors/blockiterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
Base.length(iter::BlockIterator) = length(iter.structure)
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)

Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)
8 changes: 4 additions & 4 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ function add_transform!(tdst::AbstractTensorMap,
fusiontreetransform,
α::Number,
β::Number,
backend::AbstractBackend...)
backend::TensorKitBackend, allocator)
return add_transform!(tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
backend...)
backend, allocator)
end

# VectorInterface
Expand All @@ -173,8 +173,8 @@ end

function TO.tensoradd!(C::AbstractTensorMap,
A::BraidingTensor, pA::Index2Tuple, conjA::Symbol,
α::Number, β::Number, backend=TO.DefaultBackend(),
allocator=TO.DefaultAllocator())
α::Number, β::Number, backend::AbstractBackend,
allocator)
return TO.tensoradd!(C, TensorMap(A), pA, conjA, α, β, backend, allocator)
end

Expand Down
Loading
Loading