From 8b3d04a9506ca78ee85488f917fb2d2b819304ad Mon Sep 17 00:00:00 2001 From: Hans Dembinski Date: Tue, 24 Sep 2024 18:24:44 +0200 Subject: [PATCH] fix: rvs in jit (#110) Closes #108 dist.rvs now also works in a compiled context. --------- Co-authored-by: Jan wagner Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hans Dembinski --- src/numba_stats/_util.py | 2 +- tests/test_norm.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/numba_stats/_util.py b/src/numba_stats/_util.py index 01dc7e9..d9b1d7f 100644 --- a/src/numba_stats/_util.py +++ b/src/numba_stats/_util.py @@ -187,7 +187,7 @@ def {fname}({args}): @_overload({fname}, inline="always") def _ol_{fname}({args}): - return {impl} + return {impl}.__wrapped__ """ else: code = f""" diff --git a/tests/test_norm.py b/tests/test_norm.py index f09b428..0556bc6 100644 --- a/tests/test_norm.py +++ b/tests/test_norm.py @@ -65,3 +65,12 @@ def test(x): y = test(x) assert_allclose(y, fn(x, 0, 1)) + + +@pytest.mark.filterwarnings("error") +def test_rvs_njit(): + @nb.njit + def test(): + return norm.rvs(0.0, 1.0, 10, 1) + + assert_allclose(test(), norm.rvs(0, 1, 10, 1))