Skip to content

Commit

Permalink
Potentially support parallelization (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski authored Mar 14, 2022
1 parent 9413797 commit af1743f
Show file tree
Hide file tree
Showing 17 changed files with 174 additions and 184 deletions.
17 changes: 17 additions & 0 deletions benchmarks/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,20 @@ def test_t_ppf_speed(benchmark, which, n):
m = np.linspace(-1, 1, n)
s = np.linspace(0.1, 1, n)
benchmark(lambda: sc.t.ppf(x, df, m, s) if which == "scipy" else t.ppf(x, df, m, s))


@pytest.mark.benchmark(group="bernstein.density")
@pytest.mark.parametrize("which", ("scipy", "ours"))
@pytest.mark.parametrize("n", (10, 100, 1000, 10000))
def test_bernstein_density_speed(benchmark, which, n):
    """Benchmark Bernstein-polynomial evaluation on ``n`` points.

    Compares ``scipy.interpolate.BPoly`` (constructed and evaluated inside
    the timed call) against ``numba_stats.bernstein.density``.
    """
    from numba_stats import bernstein
    from scipy.interpolate import BPoly

    x = np.linspace(0, 1, n)
    beta = np.arange(1, 4, dtype=float)
    xmin, xmax = x[0], x[-1]

    # Choose the implementation once, outside the timed callable, so the
    # measurement does not include the string comparison on ``which``.
    if which == "scipy":

        def run():
            # ``beta`` is already a float array; no extra copy is needed
            # before adding the interval axis BPoly expects.
            return BPoly(beta[:, np.newaxis], [xmin, xmax])(x)

    else:

        def run():
            return bernstein.density(x, beta, xmin, xmax)

    benchmark(run)
8 changes: 5 additions & 3 deletions src/numba_stats/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numba.types import Array
from numba.core.errors import TypingError
from numba.extending import overload
from numba import prange as _prange # noqa

_Floats = (nb.float32, nb.float64)

Expand All @@ -13,7 +14,7 @@ def _readonly_carray(T):

def _jit(arg, cache=True):
if isinstance(arg, list):
return nb.njit(arg, cache=cache, error_model="numpy")
return nb.njit(arg, cache=cache, inline="always", error_model="numpy")

signatures = []
for T in (nb.float32, nb.float64):
Expand All @@ -23,7 +24,7 @@ def _jit(arg, cache=True):
sig = T[:](_readonly_carray(T), *[T for _ in range(arg)])
signatures.append(sig)

return nb.njit(signatures, cache=cache, error_model="numpy")
return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")


def _wrap(fn):
Expand Down Expand Up @@ -79,11 +80,12 @@ def _generate_wrappers(d):
"cdf": "Return cumulative probability.",
"ppf": "Return quantile for given probability.",
}.get(fname, None)

code = f"""
def {fname}({args}):
return _wrap({impl})({args})
@_overload({fname})
@_overload({fname}, inline="always")
def _ol_{fname}({args}):
_type_check({args})
return {impl}.__wrapped__
Expand Down
22 changes: 10 additions & 12 deletions src/numba_stats/bernstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,16 @@
def _de_castlejau(z, beta):
# De Casteljau algorithm, numerically stable
n = len(beta)
res = np.empty_like(z)
if n == 0:
res[:] = np.nan
else:
betai = np.empty_like(beta)
for i, zi in enumerate(z):
azi = 1.0 - zi
betai[:] = beta
for j in range(1, n):
for k in range(n - j):
betai[k] = betai[k] * azi + betai[k + 1] * zi
res[i] = betai[0]
res = np.full_like(z, np.nan)
betai = np.empty_like(beta)
# not sure how to parallelize this, each worker thread needs its own betai
for i in range(len(z)):
betai[:] = beta
azi = 1.0 - z[i]
for j in range(1, n):
for k in range(n - j):
betai[k] = betai[k] * azi + betai[k + 1] * z[i]
res[i] = betai[0]
return res


Expand Down
7 changes: 4 additions & 3 deletions src/numba_stats/cpoisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
There is a Meijer G-function implemented in mpmath, but I don't know how to use it.
"""
from ._special import gammaincc as _gammaincc
from ._util import _jit, _generate_wrappers
from ._util import _jit, _generate_wrappers, _prange
import numpy as np

_doc_par = """
Expand All @@ -31,8 +31,9 @@
@_jit(1, cache=False)
def _cdf(x, mu):
    """Return ``_gammaincc(x[i] + 1, mu)`` for every element of ``x``.

    Parameters
    ----------
    x : array
        Points at which to evaluate the cumulative probability.
    mu : float
        Second argument forwarded to ``_gammaincc``.

    Returns
    -------
    array
        New array like ``x`` holding the result for each element.
    """
    # Type the constant from the scalar ``mu`` rather than from ``x[0]``,
    # which does not exist when ``x`` is empty.
    # NOTE(review): relies on _jit giving x's elements and mu the same float
    # type in each signature -- confirm this also holds for the overload path.
    one = type(mu)(1)
    r = np.empty_like(x)
    for i in _prange(len(x)):
        r[i] = _gammaincc(x[i] + one, mu)
    return r


Expand Down
29 changes: 14 additions & 15 deletions src/numba_stats/crystalball.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
--------
scipy.stats.crystalball: Scipy equivalent.
"""
from ._util import _jit, _trans, _generate_wrappers
from ._util import _jit, _trans, _generate_wrappers, _prange
import numpy as np
from math import erf as _erf

Expand All @@ -27,15 +27,17 @@

@_jit(-3)
def _log_powerlaw(z, beta, m):
    """Return ``log(a) - m * log(b - z)`` for the power-law tail.

    Here ``log(a) = m * log(m / beta) - beta**2 / 2`` and
    ``b = m / beta - beta``.
    """
    half = type(beta)(0.5)
    ratio = m / beta
    gauss_term = -half * beta * beta
    log_amplitude = m * np.log(ratio) + gauss_term
    offset = ratio - beta
    return log_amplitude - m * np.log(offset - z)


@_jit(-3)
def _powerlaw_integral(z, beta, m):
exp_beta = np.exp(-type(beta)(0.5) * beta * beta)
T = type(beta)
exp_beta = np.exp(-T(0.5) * beta * beta)
a = (m / beta) ** m * exp_beta
b = m / beta - beta
m1 = m - type(m)(1)
Expand All @@ -44,12 +46,9 @@ def _powerlaw_integral(z, beta, m):

@_jit(-2)
def _normal_integral(a, b):
    """Return ``sqrt(pi / 2) * (erf(b / sqrt(2)) - erf(a / sqrt(2)))``.

    This is the integral of ``exp(-t**2 / 2)`` from ``a`` to ``b``.
    """
    flt = type(a)
    inv_sqrt2 = np.sqrt(flt(0.5))
    prefactor = inv_sqrt2 * np.sqrt(flt(np.pi))
    erf_upper = _erf(b * inv_sqrt2)
    erf_lower = _erf(a * inv_sqrt2)
    return prefactor * (erf_upper - erf_lower)


@_jit(-3)
Expand All @@ -66,8 +65,8 @@ def _logpdf(x, beta, m, loc, scale):
_powerlaw_integral(-beta, beta, m) + _normal_integral(-beta, type(beta)(np.inf))
)
c = np.log(norm)
for i, zi in enumerate(z):
z[i] = _log_density(zi, beta, m) - c
for i in _prange(len(z)):
z[i] = _log_density(z[i], beta, m) - c
return z


Expand All @@ -82,12 +81,12 @@ def _cdf(x, beta, m, loc, scale):
norm = _powerlaw_integral(-beta, beta, m) + _normal_integral(
-beta, type(beta)(np.inf)
)
for i, zi in enumerate(z):
if zi < -beta:
z[i] = _powerlaw_integral(zi, beta, m) / norm
for i in _prange(len(z)):
if z[i] < -beta:
z[i] = _powerlaw_integral(z[i], beta, m) / norm
else:
z[i] = (
_powerlaw_integral(-beta, beta, m) + _normal_integral(-beta, zi)
_powerlaw_integral(-beta, beta, m) + _normal_integral(-beta, z[i])
) / norm
return z

Expand Down
30 changes: 15 additions & 15 deletions src/numba_stats/crystalball_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

from .crystalball import _powerlaw_integral, _normal_integral, _log_density
from ._util import _jit, _generate_wrappers
from ._util import _jit, _generate_wrappers, _prange
import numpy as np

_doc_par = """
Expand All @@ -35,9 +35,8 @@

@_jit(-3)
def _norm_half(beta, m, scale):
    """Return the scaled normalisation contribution of one half.

    Adds the power-law integral evaluated at ``-beta`` to the normal
    integral from ``-beta`` to 0, then multiplies the sum by ``scale``.
    """
    zero = type(beta)(0)
    tail_part = _powerlaw_integral(-beta, beta, m)
    core_part = _normal_integral(-beta, zero)
    return (tail_part + core_part) * scale


@_jit(7)
Expand All @@ -47,15 +46,15 @@ def _logpdf(x, beta_left, m_left, scale_left, beta_right, m_right, scale_right,
)
c = np.log(norm)
r = np.empty_like(x)
for i, xi in enumerate(x):
if xi < loc:
for i in _prange(len(r)):
if x[i] < loc:
beta = beta_left
m = m_left
z = (xi - loc) * (type(scale_left)(1) / scale_left)
z = (x[i] - loc) / scale_left
else:
beta = beta_right
m = m_right
z = (loc - xi) * (type(scale_right)(1) / scale_right)
z = (loc - x[i]) / scale_right
r[i] = _log_density(z, beta, m) - c
return r

Expand All @@ -78,13 +77,14 @@ def _pdf(x, beta_left, m_left, scale_left, beta_right, m_right, scale_right, loc

@_jit(7)
def _cdf(x, beta_left, m_left, scale_left, beta_right, m_right, scale_right, loc):
T = type(beta_left)
norm = _norm_half(beta_left, m_left, scale_left) + _norm_half(
beta_right, m_right, scale_right
)
r = np.empty_like(x)
for i, xi in enumerate(x):
scale = type(scale_left)(1) / (scale_left if xi < loc else scale_right)
z = (xi - loc) * scale
for i in _prange(len(x)):
scale = T(1) / (scale_left if x[i] < loc else scale_right)
z = (x[i] - loc) * scale
if z < -beta_left:
r[i] = _powerlaw_integral(z, beta_left, m_left) * scale_left / norm
elif z < 0:
Expand All @@ -100,20 +100,20 @@ def _cdf(x, beta_left, m_left, scale_left, beta_right, m_right, scale_right, loc
r[i] = (
(
_powerlaw_integral(-beta_left, beta_left, m_left)
+ _normal_integral(-beta_left, type(beta_left)(0))
+ _normal_integral(-beta_left, T(0))
)
* scale_left
+ _normal_integral(0, z) * scale_right
+ _normal_integral(T(0), z) * scale_right
) / norm
else:
r[i] = (
(
_powerlaw_integral(-beta_left, beta_left, m_left)
+ _normal_integral(-beta_left, type(beta_left)(0))
+ _normal_integral(-beta_left, T(0))
)
* scale_left
+ (
_normal_integral(type(beta_right)(0), beta_right)
_normal_integral(T(0), beta_right)
+ _powerlaw_integral(-beta_right, beta_right, m_right)
- _powerlaw_integral(-z, beta_right, m_right)
)
Expand Down
10 changes: 5 additions & 5 deletions src/numba_stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import numpy as np
from math import expm1 as _expm1, log1p as _log1p
from ._util import _jit, _trans, _generate_wrappers
from ._util import _jit, _trans, _generate_wrappers, _prange

_doc_par = """
x: ArrayLike
Expand Down Expand Up @@ -77,16 +77,16 @@ def _cdf(x, loc, scale):
Function evaluated at x.
"""
z = _trans(x, loc, scale)
for i, zi in enumerate(z):
z[i] = _cdf1(zi)
for i in _prange(len(z)):
z[i] = _cdf1(z[i])
return z


@_jit(2)
def _ppf(p, loc, scale):
    """Return the quantile for each probability in ``p``.

    Applies ``_ppf1`` element-wise, then maps the result through
    ``scale * q + loc``.
    """
    out = np.empty_like(p)
    count = len(out)
    for idx in _prange(count):
        out[idx] = _ppf1(p[idx])
    return scale * out + loc


Expand Down
25 changes: 12 additions & 13 deletions src/numba_stats/lognorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import numpy as np
from . import norm as _norm
from ._util import _jit, _trans, _generate_wrappers
from ._util import _jit, _trans, _generate_wrappers, _prange

_doc_par = """
x : ArrayLike
Expand All @@ -24,10 +24,10 @@
@_jit(3)
def _logpdf(x, s, loc, scale):
r = _trans(x, loc, scale)
for i, ri in enumerate(r):
if ri > 0:
r[i] = -0.5 * np.log(ri) ** 2 / s**2 - np.log(
s * ri * np.sqrt(2 * np.pi) * scale
for i in _prange(len(r)):
if r[i] > 0:
r[i] = -0.5 * np.log(r[i]) ** 2 / s**2 - np.log(
s * r[i] * np.sqrt(2 * np.pi) * scale
)
else:
r[i] = -np.inf
Expand All @@ -42,22 +42,21 @@ def _pdf(x, s, loc, scale):
@_jit(3)
def _cdf(x, s, loc, scale):
    """Return the cumulative probability for each element of ``x``.

    Elements whose transformed value is ``<= 0`` give 0; all others give
    ``_norm._cdf1(log(value) / s)``.
    """
    # _trans is defined elsewhere; presumably it returns a fresh array of the
    # standardised values, which is then overwritten in place.
    out = _trans(x, loc, scale)
    for idx in _prange(len(out)):
        val = out[idx]
        # Keep the ``<= 0`` test (not ``> 0``) so NaN falls through to the
        # normal-cdf branch.
        out[idx] = 0.0 if val <= 0 else _norm._cdf1(np.log(val) / s)
    return out


@_jit(3, cache=False)  # no cache because of norm._ppf
def _ppf(p, s, loc, scale):
    """Return the quantile for each probability in ``p``.

    Computes ``exp(s * _norm._ppf1(p[i]))`` element-wise, then maps the
    result through ``scale * q + loc``.
    """
    out = np.empty_like(p)
    for idx in _prange(len(p)):
        normal_quantile = _norm._ppf1(p[idx])
        out[idx] = np.exp(s * normal_quantile)
    return scale * out + loc


_generate_wrappers(globals())
14 changes: 7 additions & 7 deletions src/numba_stats/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import numpy as np
from ._special import erfinv as _erfinv
from ._util import _jit, _trans, _generate_wrappers
from ._util import _jit, _trans, _generate_wrappers, _prange
from math import erf as _erf

_doc_par = """
Expand Down Expand Up @@ -42,8 +42,8 @@ def _ppf1(p):
@_jit(2)
def _logpdf(x, loc, scale):
    """Return the log-density for each element of ``x``.

    Applies ``_logpdf1`` to each transformed value and subtracts
    ``log(scale)``.
    """
    # _trans is defined elsewhere; presumably it returns a fresh array of the
    # standardised values, which is then overwritten in place.
    r = _trans(x, loc, scale)
    # Loop-invariant: compute the log of the scale once, not per element.
    log_scale = np.log(scale)
    for i in _prange(len(r)):
        r[i] = _logpdf1(r[i]) - log_scale
    return r


Expand All @@ -55,16 +55,16 @@ def _pdf(x, loc, scale):
@_jit(2)
def _cdf(x, loc, scale):
    """Return the cumulative probability for each element of ``x``.

    Applies ``_cdf1`` to every value produced by ``_trans(x, loc, scale)``.
    """
    out = _trans(x, loc, scale)
    count = len(out)
    for idx in _prange(count):
        out[idx] = _cdf1(out[idx])
    return out


@_jit(2, cache=False)
def _ppf(p, loc, scale):
    """Return the quantile for each probability in ``p``.

    Computes ``scale * _ppf1(p[i]) + loc`` element-wise into a new array.
    """
    out = np.empty_like(p)
    for idx in _prange(len(out)):
        std_quantile = _ppf1(p[idx])
        out[idx] = scale * std_quantile + loc
    return out


Expand Down
Loading

0 comments on commit af1743f

Please sign in to comment.