Skip to content

Commit 22739bb

Browse files
committed
Add functions for controlling schedulers
1 parent cfc91a9 commit 22739bb

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

src/tensors/backends.jl

+80
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,88 @@
11
# Scheduler implementation
22
# ------------------------
3+
"""
4+
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())
5+
6+
The default scheduler used when looping over different blocks in the matrix representation of a
7+
tensor.
8+
9+
For controlling this value, see also [`set_blockscheduler`](@ref) and [`with_blockscheduler`](@ref).
10+
"""
311
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())
12+
13+
"""
14+
cosnt subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())
15+
16+
The default scheduler used when looping over different subblocks in a tensor.
17+
18+
For controlling this value, see also [`set_subblockscheduler`](@ref) and [`with_subblockscheduler`](@ref).
19+
"""
420
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())
521

22+
function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
23+
return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs)
24+
Threads.nthreads() == 1 ? SerialScheduler() : DynamicScheduler()
25+
else
26+
OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)
27+
end
28+
end
29+
30+
"""
31+
set_blockscheduler!([scheduler]; kwargs...) -> previuos
32+
33+
Set the default scheduler used in looping over the different blocks in the matrix representation
34+
of a tensor.
35+
The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with optional
36+
set of keywords arguments. For a detailed description, consult the
37+
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).
38+
39+
See also [`with_blockscheduler`](@ref).
40+
"""
41+
function set_blockscheduler!(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
42+
previous = blockscheduler[]
43+
blockscheduler[] = select_scheduler(scheduler; kwargs...)
44+
return previous
45+
end
46+
47+
"""
48+
with_blockscheduler(f, [scheduler]; kwargs...)
49+
50+
Run `f` in a scope where the `blockscheduler` is determined by `scheduler` and `kwargs...`.
51+
52+
See also [`set_blockscheduler!`](@ref).
53+
"""
54+
function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
55+
@with blockscheduler => select_scheduler(scheduler; kwargs...) f()
56+
end
57+
58+
"""
59+
set_subblockscheduler!([scheduler]; kwargs...) -> previous
60+
61+
Set the default scheduler used in looping over the different subblocks in a tensor.
62+
The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with optional
63+
set of keywords arguments. For a detailed description, consult the
64+
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).
65+
66+
See also [`with_subblockscheduler`](@ref).
67+
"""
68+
function set_subblockscheduler!(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
69+
previous = subblockscheduler[]
70+
subblockscheduler[] = select_scheduler(scheduler; kwargs...)
71+
return previous
72+
end
73+
74+
"""
75+
with_subblockscheduler(f, [scheduler]; kwargs...)
76+
77+
Run `f` in a scope where the [`subblockscheduler`](@ref) is determined by `scheduler` and `kwargs...`.
78+
79+
See also [`set_subblockscheduler!`](@ref).
80+
"""
81+
function with_subblockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();
82+
kwargs...)
83+
@with subblockscheduler => select_scheduler(scheduler; kwargs...) f()
84+
end
85+
686
# Backend implementation
787
# ----------------------
888
# TODO: figure out a name

0 commit comments

Comments
 (0)