-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathbackends.jl
84 lines (69 loc) · 3.15 KB
/
backends.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Scheduler implementation
# ------------------------
# NOTE(review): `set_blockscheduler` is not defined in this file; confirm it exists elsewhere,
# otherwise drop that cross-reference from the docstring below.
"""
    const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

Scoped value holding the scheduler that is used by default when looping over the different
blocks in the matrix representation of a tensor. Its initial value is `SerialScheduler()`.

To control this value, see [`with_blockscheduler`](@ref) and also
[`set_blockscheduler`](@ref).
"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())
# NOTE(review): `set_subblockscheduler` is not defined in this file; confirm it exists
# elsewhere, otherwise drop that cross-reference from the docstring below.
"""
    const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())

Scoped value holding the scheduler that is used by default when looping over the different
subblocks of a tensor. Its initial value is `SerialScheduler()`.

To control this value, see [`with_subblockscheduler`](@ref) and also
[`set_subblockscheduler`](@ref).
"""
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())
"""
    select_scheduler([scheduler]; kwargs...)

Turn user input into a scheduler.

If neither `scheduler` nor any keyword argument is supplied, return `SerialScheduler()` when
Julia runs with a single thread (`Threads.nthreads() == 1`) and `DynamicScheduler()`
otherwise. In every other case, `scheduler` (e.g. an `OhMyThreads.Scheduler` or a `Symbol`)
and `kwargs` are forwarded to `OhMyThreads.Implementation._scheduler_from_userinput`.
"""
function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
    # `NotGiven` is a sentinel type. Testing it with `isa` is resolved at compile time and,
    # unlike `==`, never dispatches to an equality method defined for `scheduler`'s type.
    if scheduler isa OhMyThreads.Implementation.NotGiven && isempty(kwargs)
        return Threads.nthreads() == 1 ? SerialScheduler() : DynamicScheduler()
    end
    return OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)
end
"""
    with_blockscheduler(f, [scheduler]; kwargs...)

Run `f()` in a scope where [`blockscheduler`](@ref) is determined by `scheduler` and
`kwargs...`, and return the result of `f()`.

The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol`, with an
optional set of keyword arguments. For a detailed description, consult the
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).
When neither `scheduler` nor keyword arguments are given, a `SerialScheduler()` is used if
Julia runs with a single thread, and a `DynamicScheduler()` otherwise.

See also [`with_subblockscheduler`](@ref).
"""
@inline function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();
                                     kwargs...)
    return @with blockscheduler => select_scheduler(scheduler; kwargs...) f()
end
"""
    with_subblockscheduler(f, [scheduler]; kwargs...)

Run `f()` in a scope where [`subblockscheduler`](@ref) is determined by `scheduler` and
`kwargs...`, and return the result of `f()`.

The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol`, with an
optional set of keyword arguments. For a detailed description, consult the
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).
When neither `scheduler` nor keyword arguments are given, a `SerialScheduler()` is used if
Julia runs with a single thread, and a `DynamicScheduler()` otherwise.

See also [`with_blockscheduler`](@ref).
"""
@inline function with_subblockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();
                                        kwargs...)
    return @with subblockscheduler => select_scheduler(scheduler; kwargs...) f()
end
# Backend implementation
# ----------------------
# TODO: figure out a name
# TODO: what should be the default scheduler?
"""
    TensorKitBackend(; arraybackend=TO.DefaultBackend(),
                       blockscheduler=blockscheduler[],
                       subblockscheduler=subblockscheduler[])

`AbstractBackend` returned by `TO.select_backend` for `TO.tensoradd!`, `TO.tensortrace!`,
`TO.tensorcontract!` and `add_transform!` when called with `AbstractTensorMap` arguments.

# Fields
- `arraybackend::B`: an `AbstractBackend`; defaults to `TO.DefaultBackend()`.
- `blockscheduler::BS`: scheduler for looping over the blocks of a tensor; defaults to the
  value of the `blockscheduler` scoped value at the time of construction.
- `subblockscheduler::SBS`: scheduler for looping over the subblocks of a tensor; defaults to
  the value of the `subblockscheduler` scoped value at the time of construction.
"""
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
    arraybackend::B = TO.DefaultBackend()
    # NOTE(review): the keyword default below reads the module-level scoped value of the same
    # name; `[]` fetches its current value, so defaults follow any enclosing `@with` scope.
    blockscheduler::BS = blockscheduler[]
    subblockscheduler::SBS = subblockscheduler[]
end
# Route `TO.tensoradd!` on `AbstractTensorMap`s to a default-constructed `TensorKitBackend`.
TO.select_backend(::typeof(TO.tensoradd!), ::AbstractTensorMap, ::AbstractTensorMap) =
    TensorKitBackend()
# Route `TO.tensortrace!` on `AbstractTensorMap`s to a default-constructed `TensorKitBackend`.
TO.select_backend(::typeof(TO.tensortrace!), ::AbstractTensorMap, ::AbstractTensorMap) =
    TensorKitBackend()
# Route `TO.tensorcontract!` on `AbstractTensorMap`s to a default-constructed
# `TensorKitBackend`.
TO.select_backend(::typeof(TO.tensorcontract!), ::AbstractTensorMap,
                  ::AbstractTensorMap, ::AbstractTensorMap) = TensorKitBackend()
# Generic function declared without methods, so that the `select_backend` method below can
# dispatch on `typeof(add_transform!)`.
# NOTE(review): methods are presumably added elsewhere in the package; confirm.
function add_transform! end

# Route `add_transform!` on `AbstractTensorMap`s to a default-constructed `TensorKitBackend`.
TO.select_backend(::typeof(add_transform!), ::AbstractTensorMap, ::AbstractTensorMap) =
    TensorKitBackend()