
Commit a3b1b1e

parallelize sets numba cores to 1 (#1008)
* init
* save working progress
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* finalize the tests
* finish tests and update docs
* fix pyproject
* fix serialization issue
* remove test_parallelize to check the new speed of tests
* add parallelize tests again
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* replace isolate
* use xdist

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8b74755 commit a3b1b1e
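
For context, a minimal, self-contained sketch of the pattern this commit applies inside squidpy's parallelize helper: every joblib job is routed through a small wrapper that pins numba to a single thread before calling the actual worker, so the joblib backend alone decides how many cores are used. The names heavy_kernel and _pin_numba_and_run are illustrative, not squidpy API.

import joblib as jl
import numba
import numpy as np


@numba.njit(parallel=True)
def heavy_kernel(x):
    # A numba kernel that would otherwise fan out over all available threads.
    out = np.empty_like(x)
    for i in numba.prange(x.shape[0]):
        out[i] = x[i] * 2
    return out


def _pin_numba_and_run(fn, *args, **kwargs):
    # Mirrors _callback_wrapper from this commit: one numba thread per joblib job.
    numba.set_num_threads(1)
    return fn(*args, **kwargs)


if __name__ == "__main__":
    chunks = [np.arange(10_000) for _ in range(4)]
    results = jl.Parallel(n_jobs=2, backend="loky")(
        jl.delayed(_pin_numba_and_run)(heavy_kernel, chunk) for chunk in chunks
    )
    print(len(results))  # 4 chunks, each processed with numba pinned to 1 thread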

4 files changed, +120 -5 lines changed


pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,7 @@ test = [
     "pytest-mock>=3.5.0",
     "pytest-cov>=4",
     "coverage[toml]>=7",
+    "psutil",
 ]
 
 docs = [
     "ipython",
@@ -231,3 +232,4 @@ ban-relative-imports = "all"
 filterwarnings = [
     "error::numba.NumbaPerformanceWarning"
 ]
+addopts = "-n auto"

src/squidpy/_docs.py

Lines changed: 2 additions & 2 deletions
@@ -107,8 +107,8 @@ def decorator2(obj: Any) -> Any:
 """
 _parallelize = """\
 n_jobs
-    Number of parallel jobs to use. If the function uses numba compiled functions, numba may
-    use cores depending on the number of threads set in the environment regardless of this argument.
+    Number of parallel jobs to use. The number of cores used by numba will be set to 1 regardless of this argument
+    since the backend will create a new process or thread for each job.
 backend
     Parallelization backend to use. See :class:`joblib.Parallel` for available options.
 show_progress_bar
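
A hedged usage sketch of how the shared parameters documented by this fragment are passed in practice. It assumes sq.gr.nhood_enrichment (one of the functions that interpolates this docstring) accepts n_jobs, backend and show_progress_bar, and that the example IMC dataset carries a "cell type" annotation; both are assumptions for illustration, not guaranteed by this diff.

import squidpy as sq

# Hypothetical call exercising the documented parallelization parameters.
adata = sq.datasets.imc()
sq.gr.spatial_neighbors(adata)
sq.gr.nhood_enrichment(
    adata,
    cluster_key="cell type",
    n_jobs=4,               # joblib workers; numba inside each worker is now pinned to 1 thread
    backend="loky",         # see joblib.Parallel for available backends
    show_progress_bar=False,
)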

src/squidpy/_utils.py

Lines changed: 12 additions & 3 deletions
@@ -14,6 +14,7 @@
 from typing import TYPE_CHECKING, Any
 
 import joblib as jl
+import numba
 import numpy as np
 
 __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]
@@ -55,6 +56,11 @@ def _unique_order_preserving(
     return [i for i in iterable if not (i in seen or seen_add(i))], seen
 
 
+def _callback_wrapper(chosen_runner: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
+    numba.set_num_threads(1)
+    return chosen_runner(*args, **kwargs)
+
+
 class Signal(Enum):
     """Signaling values when informing parallelizer."""
 
@@ -163,18 +169,21 @@ def update(pbar: tqdm.std.tqdm, queue: SigQueue, n_total: int) -> None:
         if pbar is not None:
             pbar.close()
 
+    chosen_runner = runner if use_runner else callback
+
     def wrapper(*args: Any, **kwargs: Any) -> Any:
+        numba.set_num_threads(1)
         if pass_queue and show_progress_bar:
             pbar = None if tqdm is None else tqdm(total=col_len, unit=unit)
             queue = Manager().Queue()
-            thread = Thread(target=update, args=(pbar, queue, len(collections)))
+            thread = Thread(target=update, args=(pbar, queue, len(collections)), name="ParallelizeUpdateThread")
             thread.start()
         else:
             pbar, queue, thread = None, None, None
 
         res = jl.Parallel(n_jobs=n_jobs, backend=backend)(
-            jl.delayed(runner if use_runner else callback)(
-                *((i, cs) if use_ixs else (cs,)),
+            jl.delayed(_callback_wrapper)(
+                *((chosen_runner, i, cs) if use_ixs else (chosen_runner, cs)),
                 *args,
                 **kwargs,
                 queue=queue,
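
A hedged usage sketch of the updated parallelize helper, modeled on the new test added below; the runner function here is illustrative, not squidpy API. Each joblib job now goes through _callback_wrapper, so numba inside every worker runs single-threaded while joblib controls how many workers exist.

import numpy as np

from squidpy._utils import Signal, parallelize


def runner(chunk, offset, queue=None):
    # Receives one split of the collection plus any extra positional args;
    # `queue` carries progress signals and is None when the progress bar is disabled.
    out = [x * 2 + offset for x in chunk]
    if queue is not None:
        queue.put(Signal.FINISH)
    return out


arrays = [np.arange(4) for _ in range(8)]
run = parallelize(runner, arrays, n_jobs=2, n_split=2, backend="loky", show_progress_bar=False)
results = run(np.arange(4))  # list with one entry per split, each processed in its own job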

tests/utils/test_parallelize.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+"""Tests for verifying process/thread usage in parallelized functions."""
+
+from __future__ import annotations
+
+import time
+from collections.abc import Callable
+from functools import partial
+
+import dask.array as da
+import numba
+import numpy as np
+import psutil
+import pytest  # type: ignore[import]
+
+from squidpy._utils import Signal, parallelize
+
+# Functions to be parallelized
+
+
+@numba.njit(parallel=True)
+def numba_parallel_func(x, y) -> np.ndarray:
+    return x * 2 + y
+
+
+@numba.njit(parallel=False)
+def numba_serial_func(x, y) -> np.ndarray:
+    return x * 2 + y
+
+
+def dask_func(x, y) -> np.ndarray:
+    return (da.from_array(x) * 2 + y).compute()
+
+
+def vanilla_func(x, y) -> np.ndarray:
+    return x * 2 + y
+
+
+# Mock runner function
+
+
+def mock_runner(x, y, queue, func):
+    for i in range(len(x)):
+        x[i] = func(x[i], y)
+        if queue is not None:
+            queue.put(Signal.UPDATE)
+    if queue is not None:
+        queue.put(Signal.FINISH)
+    return x
+
+
+@pytest.fixture(params=["numba_parallel", "numba_serial", "dask", "vanilla"])
+def func(request) -> Callable:
+    return {
+        "numba_parallel": numba_parallel_func,
+        "numba_serial": numba_serial_func,
+        "dask": dask_func,
+        "vanilla": vanilla_func,
+    }[request.param]
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.parametrize("n_jobs", [1, 2, 8])
+def test_parallelize_loky(func, n_jobs):
+    start_time = time.time()
+    seed = 42
+    rng = np.random.RandomState(seed)
+    n = 8
+    arr1 = [rng.randint(0, 100, n) for _ in range(n)]
+    arr2 = np.arange(n)
+    runner = partial(mock_runner, func=func)
+    # this is the expected result of the function
+    expected = [func(arr1[i], arr2) for i in range(len(arr1))]
+    # this will be set to something other than 1,2,8
+    # we want to check if setting the threads works
+    # then after the function is run if the numba cores are set back to 1
+    old_num_threads = 3
+    numba.set_num_threads(old_num_threads)
+    # Get initial state
+    initial_process = psutil.Process()
+    initial_children = {p.pid for p in initial_process.children(recursive=True)}
+    initial_children = {psutil.Process(pid) for pid in initial_children}
+    init_numba_threads = numba.get_num_threads()
+
+    p_func = parallelize(runner, arr1, n_jobs=n_jobs, backend="loky", use_ixs=False, n_split=1)
+    result = p_func(arr2)[0]
+
+    final_children = {p.pid for p in initial_process.children(recursive=True)}
+    final_numba_threads = numba.get_num_threads()
+
+    assert init_numba_threads == old_num_threads, "Numba threads should not change"
+    assert final_numba_threads == 1, "Numba threads should be 1"
+    assert len(result) == len(expected), f"Expected: {expected} but got {result}. Length mismatch"
+    for i in range(len(arr1)):
+        assert np.all(result[i] == expected[i]), f"Expected {expected[i]} but got {result[i]}"
+
+    processes = final_children - initial_children
+
+    processes = {psutil.Process(pid) for pid in processes}
+    processes = {p for p in processes if not any("resource_tracker" in cl for cl in p.cmdline())}
+    if n_jobs > 1:  # expect exactly n_jobs
+        assert len(processes) == n_jobs, f"Unexpected processes created or not created: {processes}"
+    else:  # some functions use the main process others use a new process
+        processes = {p for p in processes if p.create_time() > start_time}
+        assert len(processes) <= 1, f"Unexpected processes created or not created: {processes}"
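
For readers unfamiliar with the psutil bookkeeping above, a standalone sketch of the same idiom: snapshot the current process's children before and after a joblib call, then discard the multiprocessing resource tracker before counting new loky workers. do_work is a placeholder, not part of squidpy.

import time

import joblib as jl
import psutil


def do_work(x):
    return x * 2


if __name__ == "__main__":
    start = time.time()
    me = psutil.Process()
    before = {p.pid for p in me.children(recursive=True)}

    jl.Parallel(n_jobs=2, backend="loky")(jl.delayed(do_work)(i) for i in range(4))

    # loky keeps its workers alive after the call, so they are still visible here.
    new_pids = {p.pid for p in me.children(recursive=True)} - before
    workers = {
        p
        for p in map(psutil.Process, new_pids)
        if not any("resource_tracker" in part for part in p.cmdline()) and p.create_time() > start
    }
    print(f"loky started {len(workers)} worker process(es)")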
