Skip to content

Commit 349b75a

Browse files
Fix the problems with the latest merge (#1009)
* save working progress
* add addopts from yml
* prettify test file
* don't set the original threads as the wrapper function already does that
* make sure tox passes the PYTEST_ADDOPTS env variable
* set inner_max_num_threads=1
* set inner_max_num_threads only for loky
* fix blunder
* ensure all functions are called with n_jobs processes
* remove inner_max_num_threads to simplify stuff
* clarify numba behaviour in the docs
* increase timeout limit
* increase the timeout again
* just set inner_max_num_threads instead
* clarify in the docs that oversubscription is only handled for loky backend
* simplify and clean the parallelize test
* fix doc formatting
* add non loky alternative to test also the other case
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* undo multiprocessing test option bc it may be too slow for the ci's. (at least they pass)
* increase timeout limit
* or instead reduce the computation required
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add other cases but skip in CI; also reduce computation required again
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* refactor skip markings
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a3b1b1e commit 349b75a

File tree

6 files changed

+77
-68
lines changed

6 files changed

+77
-68
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ jobs:
8888
MPLBACKEND: agg
8989
PLATFORM: ${{ matrix.os }}
9090
DISPLAY: :42
91+
PYTEST_ADDOPTS: "-n auto"
9192
run: |
9293
tox -vv
9394
# check if this can be deprecated

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ test = [
8888
"pytest-mock>=3.5.0",
8989
"pytest-cov>=4",
9090
"coverage[toml]>=7",
91-
"psutil",
91+
"pytest-timeout>=2.1.0",
9292
]
9393
docs = [
9494
"ipython",
@@ -231,5 +231,4 @@ ban-relative-imports = "all"
231231
[tool.pytest.ini_options]
232232
filterwarnings = [
233233
"error::numba.NumbaPerformanceWarning"
234-
]
235-
addopts = "-n auto"
234+
]

src/squidpy/_docs.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,14 @@ def decorator2(obj: Any) -> Any:
107107
"""
108108
_parallelize = """\
109109
n_jobs
110-
Number of parallel jobs to use. The number of cores used by numba will be set to 1 regardless of this argument
111-
since the backend will create a new process or thread for each job.
110+
Number of parallel jobs to use.
111+
For ``backend="loky"``, the number of cores used by numba for
112+
each job spawned by the backend will be set to 1 in order to
113+
overcome the oversubscription issue in case you run
114+
numba in your function to parallelize.
115+
To set the absolute maximum number of threads in numba
116+
for your python program, set the environment variable:
117+
``NUMBA_NUM_THREADS`` before running the program.
112118
backend
113119
Parallelization backend to use. See :class:`joblib.Parallel` for available options.
114120
show_progress_bar

src/squidpy/_utils.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -172,29 +172,29 @@ def update(pbar: tqdm.std.tqdm, queue: SigQueue, n_total: int) -> None:
172172
chosen_runner = runner if use_runner else callback
173173

174174
def wrapper(*args: Any, **kwargs: Any) -> Any:
175-
numba.set_num_threads(1)
176175
if pass_queue and show_progress_bar:
177176
pbar = None if tqdm is None else tqdm(total=col_len, unit=unit)
178177
queue = Manager().Queue()
179178
thread = Thread(target=update, args=(pbar, queue, len(collections)), name="ParallelizeUpdateThread")
180179
thread.start()
181180
else:
182181
pbar, queue, thread = None, None, None
183-
184-
res = jl.Parallel(n_jobs=n_jobs, backend=backend)(
185-
jl.delayed(_callback_wrapper)(
186-
*((chosen_runner, i, cs) if use_ixs else (chosen_runner, cs)),
187-
*args,
188-
**kwargs,
189-
queue=queue,
182+
jl_kwargs = {"inner_max_num_threads": 1} if backend == "loky" else {}
183+
with jl.parallel_config(backend, n_jobs=n_jobs, **jl_kwargs):
184+
res = jl.Parallel(n_jobs=n_jobs, backend=backend)(
185+
jl.delayed(_callback_wrapper)(
186+
*((chosen_runner, i, cs) if use_ixs else (chosen_runner, cs)),
187+
*args,
188+
**kwargs,
189+
queue=queue,
190+
)
191+
for i, cs in enumerate(collections)
190192
)
191-
for i, cs in enumerate(collections)
192-
)
193193

194-
if thread is not None:
195-
thread.join()
194+
if thread is not None:
195+
thread.join()
196196

197-
return res if extractor is None else extractor(res)
197+
return res if extractor is None else extractor(res)
198198

199199
if n_jobs is None:
200200
n_jobs = 1

tests/utils/test_parallelize.py

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,26 @@
22

33
from __future__ import annotations
44

5-
import time
5+
import os
66
from collections.abc import Callable
77
from functools import partial
88

99
import dask.array as da
1010
import numba
1111
import numpy as np
12-
import psutil
1312
import pytest # type: ignore[import]
1413

1514
from squidpy._utils import Signal, parallelize
1615

1716
# Functions to be parallelized
1817

1918

19+
def wrap_numba_check(x, y, inner_function, check_threads=True):
20+
if check_threads:
21+
assert numba.get_num_threads() == 1
22+
return inner_function(x, y)
23+
24+
2025
@numba.njit(parallel=True)
2126
def numba_parallel_func(x, y) -> np.ndarray:
2227
return x * 2 + y
@@ -38,9 +43,9 @@ def vanilla_func(x, y) -> np.ndarray:
3843
# Mock runner function
3944

4045

41-
def mock_runner(x, y, queue, func):
42-
for i in range(len(x)):
43-
x[i] = func(x[i], y)
46+
def mock_runner(x, y, queue, function):
47+
for i, xi in enumerate(x):
48+
x[i] = function(xi, y, check_threads=True)
4449
if queue is not None:
4550
queue.put(Signal.UPDATE)
4651
if queue is not None:
@@ -51,54 +56,52 @@ def mock_runner(x, y, queue, func):
5156
@pytest.fixture(params=["numba_parallel", "numba_serial", "dask", "vanilla"])
5257
def func(request) -> Callable:
5358
return {
54-
"numba_parallel": numba_parallel_func,
55-
"numba_serial": numba_serial_func,
56-
"dask": dask_func,
57-
"vanilla": vanilla_func,
59+
"numba_parallel": partial(wrap_numba_check, inner_function=numba_parallel_func),
60+
"numba_serial": partial(wrap_numba_check, inner_function=numba_serial_func),
61+
"dask": partial(wrap_numba_check, inner_function=dask_func),
62+
"vanilla": partial(wrap_numba_check, inner_function=vanilla_func),
5863
}[request.param]
5964

6065

61-
@pytest.mark.timeout(60)
62-
@pytest.mark.parametrize("n_jobs", [1, 2, 8])
63-
def test_parallelize_loky(func, n_jobs):
64-
start_time = time.time()
66+
# Timeouts are also useful because some processes don't return in
67+
# case of failure.
68+
69+
70+
@pytest.mark.timeout(30)
71+
@pytest.mark.parametrize(
72+
"backend",
73+
[
74+
pytest.param(
75+
"threading",
76+
marks=pytest.mark.skipif(
77+
os.environ.get("CI") == "true", reason="Only testing 'loky' backend in CI environment"
78+
),
79+
),
80+
pytest.param(
81+
"multiprocessing",
82+
marks=pytest.mark.skipif(
83+
os.environ.get("CI") == "true", reason="Only testing 'loky' backend in CI environment"
84+
),
85+
),
86+
"loky",
87+
],
88+
)
89+
def test_parallelize(func, backend):
6590
seed = 42
91+
n = 2
92+
n_jobs = 2
6693
rng = np.random.RandomState(seed)
67-
n = 8
6894
arr1 = [rng.randint(0, 100, n) for _ in range(n)]
6995
arr2 = np.arange(n)
70-
runner = partial(mock_runner, func=func)
71-
# this is the expected result of the function
72-
expected = [func(arr1[i], arr2) for i in range(len(arr1))]
73-
# this will be set to something other than 1,2,8
74-
# we want to check if setting the threads works
75-
# then after the function is run if the numba cores are set back to 1
76-
old_num_threads = 3
77-
numba.set_num_threads(old_num_threads)
78-
# Get initial state
79-
initial_process = psutil.Process()
80-
initial_children = {p.pid for p in initial_process.children(recursive=True)}
81-
initial_children = {psutil.Process(pid) for pid in initial_children}
82-
init_numba_threads = numba.get_num_threads()
83-
84-
p_func = parallelize(runner, arr1, n_jobs=n_jobs, backend="loky", use_ixs=False, n_split=1)
85-
result = p_func(arr2)[0]
86-
87-
final_children = {p.pid for p in initial_process.children(recursive=True)}
88-
final_numba_threads = numba.get_num_threads()
89-
90-
assert init_numba_threads == old_num_threads, "Numba threads should not change"
91-
assert final_numba_threads == 1, "Numba threads should be 1"
92-
assert len(result) == len(expected), f"Expected: {expected} but got {result}. Length mismatch"
93-
for i in range(len(arr1)):
94-
assert np.all(result[i] == expected[i]), f"Expected {expected[i]} but got {result[i]}"
95-
96-
processes = final_children - initial_children
97-
98-
processes = {psutil.Process(pid) for pid in processes}
99-
processes = {p for p in processes if not any("resource_tracker" in cl for cl in p.cmdline())}
100-
if n_jobs > 1: # expect exactly n_jobs
101-
assert len(processes) == n_jobs, f"Unexpected processes created or not created: {processes}"
102-
else: # some functions use the main process others use a new process
103-
processes = {p for p in processes if p.create_time() > start_time}
104-
assert len(processes) <= 1, f"Unexpected processes created or not created: {processes}"
96+
runner = partial(mock_runner, function=func)
97+
98+
init_threads = numba.get_num_threads()
99+
expected = np.vstack([func(a1, arr2, check_threads=False) for a1 in arr1])
100+
101+
p_func = parallelize(
102+
runner, arr1, n_jobs=n_jobs, backend=backend, use_ixs=False, extractor=np.vstack, show_progress=False
103+
)
104+
result = p_func(arr2)
105+
106+
assert numba.get_num_threads() == init_threads, "Number of threads should stay the same after parallelization"
107+
assert np.allclose(result, expected), f"Expected: {expected} but got {result}"

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ extras =
8080
interactive
8181
test
8282
setenv = linux: PYTEST_FLAGS=--test-napari
83-
passenv = TOXENV,CI,CODECOV_*,GITHUB_ACTIONS,PYTEST_FLAGS,DISPLAY,XAUTHORITY,MPLBACKEND
83+
passenv = TOXENV,CI,CODECOV_*,GITHUB_ACTIONS,PYTEST_FLAGS,DISPLAY,XAUTHORITY,MPLBACKEND,PYTEST_ADDOPTS
8484
usedevelop = true
8585
commands =
8686
python -m pytest --color=yes --cov --cov-append --cov-report=xml --cov-config={toxinidir}/tox.ini --ignore docs/ {posargs:-vv} {env:PYTEST_FLAGS:}

0 commit comments

Comments
 (0)