Skip to content

Commit

Permalink
Add rvs to several distributions (#81)
Browse files Browse the repository at this point in the history
Add rvs to compute random variates distributed like the corresponding
distribution for
- norm
- lognorm
- expon
- uniform
- t
- qgaussian
- truncexpon
- truncnorm

There are some limitations imposed numba compared to the scipy
implementations:
- Arguments size and random_state cannot be omitted
- random_state can only be None or int, it is not possible to pass an
instance of a numpy.random.Generator
- Generated variates are always of type float64
  • Loading branch information
HDembinski authored Oct 11, 2023
1 parent d3ee41f commit e258c53
Show file tree
Hide file tree
Showing 24 changed files with 210 additions and 52 deletions.
7 changes: 2 additions & 5 deletions bench/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,11 @@ def test_speed_bernstein_density(benchmark, lib, n):
def method(x, beta, xmin, xmax):
return BPoly(np.array(beta)[:, np.newaxis], [xmin, xmax])(x)

else:
elif lib == "ours:parallel,fastmath":
method = bernstein.density

if lib == "ours:parallel,fastmath":

@nb.njit(parallel=True, fastmath=True)
def method(x, beta, xmin, xmax):
return bernstein.density(x, beta, xmin, xmax)
method = nb.njit(parallel=True, fastmath=True)(method)

# warm-up JIT
method(x, beta, xmin, xmax)
Expand Down
82 changes: 71 additions & 11 deletions src/numba_stats/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numba.core.errors import TypingError
from numba.extending import overload
from numba import prange as _prange # noqa
import os

_Floats = (nb.float32, nb.float64)

Expand All @@ -29,19 +30,35 @@ def _jit(arg, cache=True):
an array, the others are scalars and arg is the number of scalar arguments.
"""
if isinstance(arg, list):
return nb.njit(arg, cache=cache, inline="always", error_model="numpy")
signatures = arg
else:
signatures = []
for T in (nb.float32, nb.float64):
if arg < 0:
sig = T(*([T] * -arg))
else:
sig = T[:](_readonly_carray(T), *[T for _ in range(arg)])
signatures.append(sig)
return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")

signatures = []
for T in (nb.float32, nb.float64):
if arg < 0:
sig = T(*([T] * -arg))
else:
sig = T[:](_readonly_carray(T), *[T for _ in range(arg)])
signatures.append(sig)

def _rvs_jit(arg, cache=True):
signatures = []
T = nb.float64 # nb.float32 cannot be supported
# extra args at the end are for size and random_state
sig = T[:](*[T for _ in range(arg)], nb.uint64, nb.optional(nb.uint64))
signatures.append(sig)
return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")


@nb.njit(cache=True)
def _seed(seed):
if seed is None:
with nb.objmode(seed="optional(uint8)"):
seed = np.frombuffer(os.urandom(8), dtype=np.uint64)[0]
np.random.seed(seed)


def _wrap(fn):
def outer(first, *rest):
shape = np.shape(first)
Expand Down Expand Up @@ -80,7 +97,17 @@ def _generate_wrappers(d):

doc_par = d["_doc_par"].strip() if "_doc_par" in d else None

for fname in "pdf", "pmf", "logpdf", "logpmf", "cdf", "ppf", "density", "integral":
for fname in (
"pdf",
"pmf",
"logpdf",
"logpmf",
"cdf",
"ppf",
"density",
"integral",
"rvs",
):
impl = f"_{fname}"
if impl not in d:
continue
Expand All @@ -89,15 +116,47 @@ def _generate_wrappers(d):
args = ", ".join([f"{x}" for x in args])
doc_title = {
"density": "Return density.",
"integral": "Return integrated density.",
"logpdf": "Return log of probability density.",
"logpmf": "Return log of probability mass.",
"pmf": "Return probability mass.",
"pdf": "Return probability density.",
"cdf": "Return cumulative probability.",
"ppf": "Return quantile for given probability.",
"rvs": "Return random samples from distribution.",
}.get(fname, None)
if fname == "ppf":
before_par = """\
x: ArrayLike
Probability. Must be between 0 and 1.
"""
elif fname == "rvs":
before_par = ""
else:
before_par = """\
x: ArrayLike
Random variate.
"""
if fname == "rvs":
after_par = """
size : int
Number of random variates.
random_state : int or None
Seed of the random number generator. Default is None, which uses a random seed."""
else:
after_par = ""

if fname == "rvs":
code = f"""
def {fname}({args}):
return {impl}({args})
code = f"""
@_overload({fname}, inline="always")
def _ol_{fname}({args}):
return {impl}
"""
else:
code = f"""
def {fname}({args}):
return _wrap({impl})({args})
Expand All @@ -106,6 +165,7 @@ def _ol_{fname}({args}):
_type_check({args})
return {impl}.__wrapped__
"""

if doc_par is None:
code += f"""
{fname}.__doc__ = {impl}.__doc__
Expand All @@ -118,7 +178,7 @@ def _ol_{fname}({args}):
Parameters
----------
{doc_par}
{before_par}{doc_par}{after_par}
Returns
-------
Expand Down
2 changes: 0 additions & 2 deletions src/numba_stats/cpoisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import numpy as np

_doc_par = """
x: ArrayLike
Random variate.
mu : float
Expected value.
"""
Expand Down
2 changes: 0 additions & 2 deletions src/numba_stats/cruijff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import numpy as np

_doc_par = """
x : ArrayLike
Random variate.
beta_left: float
Left tail acceleration parameter.
beta_right: float
Expand Down
2 changes: 0 additions & 2 deletions src/numba_stats/crystalball_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import numpy as np

_doc_par = """
x : Array-like
Random variate.
beta_left : float
Distance from the mode in units of standard deviations where the Crystal Ball
turns from a gaussian into a power law on the left side.
Expand Down
11 changes: 8 additions & 3 deletions src/numba_stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
"""
import numpy as np
from math import expm1 as _expm1, log1p as _log1p
from ._util import _jit, _trans, _generate_wrappers, _prange
from ._util import _jit, _trans, _generate_wrappers, _prange, _rvs_jit, _seed

_doc_par = """
x: ArrayLike
Random variate.
loc : float
Location of the mode.
scale : float
Expand Down Expand Up @@ -90,4 +88,11 @@ def _ppf(p, loc, scale):
return scale * z + loc


@_rvs_jit(2)
def _rvs(loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, loc, scale)


_generate_wrappers(globals())
11 changes: 8 additions & 3 deletions src/numba_stats/lognorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
"""
import numpy as np
from . import norm as _norm
from ._util import _jit, _trans, _generate_wrappers, _prange
from ._util import _jit, _trans, _generate_wrappers, _prange, _seed, _rvs_jit

_doc_par = """
x : ArrayLike
Random variate.
s : float
Standard deviation of the corresponding normal distribution of exp(x).
loc : float
Expand Down Expand Up @@ -59,4 +57,11 @@ def _ppf(p, s, loc, scale):
return scale * r + loc


@_rvs_jit(3, cache=False)
def _rvs(s, loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, s, loc, scale)


_generate_wrappers(globals())
11 changes: 8 additions & 3 deletions src/numba_stats/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
"""
import numpy as np
from ._special import ndtri as _ndtri
from ._util import _jit, _trans, _generate_wrappers, _prange
from ._util import _jit, _trans, _generate_wrappers, _prange, _seed, _rvs_jit
from math import erf as _erf

_doc_par = """
x : ArrayLike
Random variate.
loc : float
Location of the mode of the distribution.
scale : float
Expand Down Expand Up @@ -68,4 +66,11 @@ def _ppf(p, loc, scale):
return r


@_rvs_jit(2, cache=False)
def _rvs(loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, loc, scale)


_generate_wrappers(globals())
2 changes: 0 additions & 2 deletions src/numba_stats/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from ._util import _jit, _generate_wrappers, _prange

_doc_par = """
x : ArrayLike
Random variate.
mu : float
Expected value.
"""
Expand Down
17 changes: 14 additions & 3 deletions src/numba_stats/qgaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
import numpy as np
import numba as nb
from . import norm as _norm, t as _t
from ._util import _jit, _generate_wrappers
from ._util import _jit, _generate_wrappers, _rvs_jit

_doc_par = """
x : ArrayLike
Random variate.
q : float
Shape parameter between 1 and 3. For q = 1, the qgaussian is a normal distribution,
for q == 3 it is a cauchy distribution.
Expand Down Expand Up @@ -87,4 +85,17 @@ def _ppf(p, q, mu, sigma):
return _t._ppf(p, df, mu, sigma)


@_rvs_jit(3, cache=False)
def _rvs(q, mu, sigma, size, random_state):
if q < 1 or q > 3:
raise ValueError("q < 1 or q > 3 are not supported")

if q == 1:
return _norm._rvs(mu, sigma, size, random_state)

df, sigma = _df_sigma(q, sigma)

return _t._rvs(df, mu, sigma, size, random_state)


_generate_wrappers(globals())
11 changes: 8 additions & 3 deletions src/numba_stats/t.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
"""
import numpy as np
from ._special import stdtr as _stdtr, stdtrit as _stdtrit
from ._util import _jit, _trans, _generate_wrappers, _prange
from ._util import _jit, _trans, _generate_wrappers, _prange, _seed, _rvs_jit
from math import lgamma as _lgamma

_doc_par = """
x: ArrayLike
Random variate.
df : float
Degrees of freedom.
loc : float
Expand Down Expand Up @@ -62,4 +60,11 @@ def _ppf(p, df, loc, scale):
return scale * r + loc


@_rvs_jit(3, cache=False)
def _rvs(df, loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, df, loc, scale)


_generate_wrappers(globals())
11 changes: 8 additions & 3 deletions src/numba_stats/truncexpon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
scipy.stats.truncexpon: Scipy equivalent.
"""
import numpy as np
from ._util import _jit, _trans, _generate_wrappers, _prange
from ._util import _jit, _trans, _generate_wrappers, _prange, _rvs_jit, _seed
from . import expon as _expon

_doc_par = """
x: ArrayLike
Random variate.
xmin : float
Lower edge of the distribution.
xmax : float
Expand Down Expand Up @@ -77,4 +75,11 @@ def _ppf(p, xmin, xmax, loc, scale):
return z * scale + loc


@_rvs_jit(4)
def _rvs(xmin, xmax, loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, xmin, xmax, loc, scale)


_generate_wrappers(globals())
11 changes: 8 additions & 3 deletions src/numba_stats/truncnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@

import numpy as np
from . import norm as _norm
from ._util import _jit, _generate_wrappers, _prange
from ._util import _jit, _generate_wrappers, _prange, _seed, _rvs_jit

_doc_par = """
x: ArrayLike
Random variate.
xmin : float
Lower edge of the distribution.
xmin : float
Expand Down Expand Up @@ -77,4 +75,11 @@ def _ppf(p, xmin, xmax, loc, scale):
return scale * r + loc


@_rvs_jit(4, cache=False)
def _rvs(xmin, xmax, loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, xmin, xmax, loc, scale)


_generate_wrappers(globals())
2 changes: 0 additions & 2 deletions src/numba_stats/tsallis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from ._util import _jit, _generate_wrappers

_doc_par = """
x : ArrayLike
Random variate.
m : float
Mass of the particle.
t : float
Expand Down
Loading

0 comments on commit e258c53

Please sign in to comment.