|
| 1 | +"""Tests for verifying process/thread usage in parallelized functions.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import time |
| 6 | +from collections.abc import Callable |
| 7 | +from functools import partial |
| 8 | + |
| 9 | +import dask.array as da |
| 10 | +import numba |
| 11 | +import numpy as np |
| 12 | +import psutil |
| 13 | +import pytest # type: ignore[import] |
| 14 | + |
| 15 | +from squidpy._utils import Signal, parallelize |
| 16 | + |
| 17 | +# Functions to be parallelized |
| 18 | + |
| 19 | + |
@numba.njit(parallel=True)
def numba_parallel_func(x, y) -> np.ndarray:
    """Return ``x * 2 + y``, compiled by numba with ``parallel=True``."""
    doubled = x * 2
    return doubled + y
| 23 | + |
| 24 | + |
@numba.njit(parallel=False)
def numba_serial_func(x, y) -> np.ndarray:
    """Return ``x * 2 + y``, compiled by numba with ``parallel=False``."""
    doubled = x * 2
    return doubled + y
| 28 | + |
| 29 | + |
def dask_func(x, y) -> np.ndarray:
    """Evaluate ``x * 2 + y`` via a dask array built from ``x`` and return the computed result."""
    lazy_x = da.from_array(x)
    lazy_result = lazy_x * 2 + y
    return lazy_result.compute()
| 32 | + |
| 33 | + |
def vanilla_func(x, y) -> np.ndarray:
    """Return ``x * 2 + y`` using plain Python/NumPy operators (no numba, no dask)."""
    doubled = x * 2
    return doubled + y
| 36 | + |
| 37 | + |
| 38 | +# Mock runner function |
| 39 | + |
| 40 | + |
def mock_runner(x, y, queue, func):
    """Apply ``func(item, y)`` to every element of ``x`` in place and return ``x``.

    When ``queue`` is not ``None``, ``Signal.UPDATE`` is put on it after each
    processed element and ``Signal.FINISH`` once after the loop.
    """
    report_progress = queue is not None
    for idx, item in enumerate(x):
        # Overwrite the slot with the transformed value; ``x`` itself is returned.
        x[idx] = func(item, y)
        if report_progress:
            queue.put(Signal.UPDATE)
    if report_progress:
        queue.put(Signal.FINISH)
    return x
| 49 | + |
| 50 | + |
@pytest.fixture(params=["numba_parallel", "numba_serial", "dask", "vanilla"])
def func(request) -> Callable:
    """Parametrized fixture yielding one of the four ``x * 2 + y`` implementations."""
    implementations = {
        "numba_parallel": numba_parallel_func,
        "numba_serial": numba_serial_func,
        "dask": dask_func,
        "vanilla": vanilla_func,
    }
    # ``request.param`` is the current entry of the fixture's ``params`` list.
    selected = request.param
    return implementations[selected]
| 59 | + |
| 60 | + |
@pytest.mark.timeout(60)
@pytest.mark.parametrize("n_jobs", [1, 2, 8])
def test_parallelize_loky(func, n_jobs):
    """Run ``parallelize`` with the loky backend and check output, numba threads and child processes.

    The test verifies that:
    - the parallel result matches a serial evaluation of ``func``;
    - the numba thread count is still the preset value before the call and is 1 after it;
    - the number of child processes (ignoring ``resource_tracker``) equals ``n_jobs`` when
      ``n_jobs > 1``, and is at most one freshly created process otherwise.
    """
    start_time = time.time()
    rng = np.random.RandomState(42)
    n = 8
    arr1 = [rng.randint(0, 100, n) for _ in range(n)]
    arr2 = np.arange(n)
    runner = partial(mock_runner, func=func)
    # Reference output, evaluated serially before the parallel call.
    expected = [func(row, arr2) for row in arr1]

    # Preset the numba thread count to a value that is none of the parametrized
    # n_jobs (1, 2, 8), so we can check it is untouched before the call and
    # set to 1 once the parallelized function has run.
    old_num_threads = 3
    numba.set_num_threads(old_num_threads)

    # Snapshot of the process tree and numba state before the parallel call.
    initial_process = psutil.Process()
    initial_pids = {child.pid for child in initial_process.children(recursive=True)}
    initial_children = {psutil.Process(pid) for pid in initial_pids}
    init_numba_threads = numba.get_num_threads()

    p_func = parallelize(runner, arr1, n_jobs=n_jobs, backend="loky", use_ixs=False, n_split=1)
    result = p_func(arr2)[0]

    # Snapshot after the parallel call.
    final_children = {child.pid for child in initial_process.children(recursive=True)}
    final_numba_threads = numba.get_num_threads()

    assert init_numba_threads == old_num_threads, "Numba threads should not change"
    assert final_numba_threads == 1, "Numba threads should be 1"
    assert len(result) == len(expected), f"Expected: {expected} but got {result}. Length mismatch"
    for i, want in enumerate(expected):
        assert np.all(result[i] == want), f"Expected {expected[i]} but got {result[i]}"

    # NOTE(review): ``final_children`` holds pids (ints) while ``initial_children`` holds
    # ``psutil.Process`` objects, so this difference presumably removes nothing — confirm
    # whether counting all current children is the intended behavior.
    candidate_pids = final_children - initial_children

    spawned = {psutil.Process(pid) for pid in candidate_pids}
    # Drop helper processes whose command line mentions "resource_tracker".
    processes = {
        proc for proc in spawned if all("resource_tracker" not in part for part in proc.cmdline())
    }
    if n_jobs <= 1:
        # some functions use the main process others use a new process
        processes = {proc for proc in processes if proc.create_time() > start_time}
        assert len(processes) <= 1, f"Unexpected processes created or not created: {processes}"
    else:
        # expect exactly n_jobs
        assert len(processes) == n_jobs, f"Unexpected processes created or not created: {processes}"
0 commit comments