Skip to content

Commit 08d2a01

Browse files
authored
Merge pull request #20 from MothNik/fix/median_performance
Fix/median performance
2 parents 5bd08f9 + b805ff7 commit 08d2a01

File tree

10 files changed

+83
-27
lines changed

10 files changed

+83
-27
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
run: |
2828
python -m pip install --upgrade pip
2929
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30+
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
3031
- name: Run tests
3132
run: |
3233
export PYTHONPATH="${PYTHONPATH}:/robustbase/"

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,13 @@ celerybeat.pid
103103

104104
# Environments
105105
.env
106-
.venv
106+
.venv*
107107
env/
108108
venv/
109109
ENV/
110110
env.bak/
111111
venv.bak/
112+
.vscode/
112113

113114
# Spyder project settings
114115
.spyderproject

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ This package provides functions to calculate the following robust statistical es
2121

2222
```python
2323
from robustbase.stats import Qn
24-
24+
2525
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
2626

2727
# With bias correction
@@ -37,11 +37,11 @@ res = Qn(x, finite_corr=False) # result: 4.43828
3737

3838
```python
3939
from robustbase.stats import Sn
40-
40+
4141
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
4242

4343
# With bias correction
44-
res = Sn(x) # result: 3.5778
44+
res = Sn(x) # result: 3.5778
4545

4646
# Without bias correction
4747
res = Sn(x, finite_corr=False) # result: 3.5778
@@ -75,7 +75,7 @@ For local development setup:
7575
```sh
7676
git clone https://github.com/deepak7376/robustbase
7777
cd robustbase
78-
pip install -r requirements.txt
78+
pip install -r requirements.txt -r requirements-dev.txt
7979
```
8080

8181
## Recent Changes

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pytest>=8.1.1

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@ certifi>=2019.11.28
22
docutils>=0.15.2
33
numpy>=1.18.0
44
statistics>=1.0.3.5
5-
pytest>=8.1.1

robustbase/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from .robustbase import Qn
2-
from .robustbase import Sn
3-
from .robustbase import iqr
4-
from .robustbase import mad
1+
from .robustbase import Qn # noqa: F401
2+
from .robustbase import Sn # noqa: F401
3+
from .robustbase import iqr # noqa: F401
4+
from .robustbase import mad # noqa: F401
55

66

robustbase/stats/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .Qn import Qn
2-
from .Sn import Sn
3-
from .iqr import iqr
4-
from .mad import mad
1+
from .iqr import iqr # noqa: F401
2+
from .mad import mad # noqa: F401
3+
from .Qn import Qn # noqa: F401
4+
from .Sn import Sn # noqa: F401

robustbase/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .mean import mean
2-
from .median import median
1+
from .mean import mean # noqa: F401
2+
from .median import median # noqa: F401

robustbase/utils/median.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,29 @@ def median(x, low=False, high=False):
77
88
Parameters:
99
- x: list or array-like, numeric vector of observations.
10-
- low: bool, if True, return the low median for even sample size.
10+
- low: bool, if True, return the low median for even sample size. If ``True``, ``high`` is ignored.
1111
- high: bool, if True, return the high median for even sample size. Ignored if ``low`` is also ``True``.
1212
1313
Returns:
1414
- float: Median value.
1515
"""
16-
sorted_x = np.sort(x)
17-
n = len(sorted_x)
16+
17+
n = len(x)
1818
if n == 0:
1919
raise ValueError("Empty list provided.")
20-
20+
21+
# for odd sample size, all three medians are the same
2122
if n % 2 == 1:
22-
return sorted_x[n // 2]
23-
elif low:
24-
return sorted_x[n // 2 - 1]
25-
elif high:
26-
return sorted_x[n // 2]
27-
else:
28-
return (sorted_x[n // 2 - 1] + sorted_x[n // 2]) / 2
23+
return np.median(a=x)
24+
25+
# for even sample sizes, the median is the average of the two middle values if
26+
# neither the low nor high median is requested
27+
if not (low or high):
28+
return np.median(a=x)
29+
30+
# otherwise, either the low or the high median is found via introselect
31+
median_idx = n // 2
32+
if low:
33+
median_idx -= 1
34+
35+
return np.partition(a=x, kth=median_idx)[median_idx]

tests/test_median.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Optional, Tuple, Union
2+
3+
import numpy as np
4+
import pytest
5+
6+
from robustbase.utils.median import median
7+
8+
X_EMPTY = []
9+
X_ODD_N = [5.5, 3.2, -10.0, -2.1, 8.4]
10+
X_EVEN_N = [5.5, 3.2, -10.0, -2.1, 8.4, 0.0]
11+
12+
13+
@pytest.mark.parametrize("as_array", [False, True])
14+
@pytest.mark.parametrize(
15+
"comb",
16+
[
17+
(X_EMPTY, False, False, None),
18+
(X_EMPTY, True, False, None),
19+
(X_EMPTY, False, True, None),
20+
(X_EMPTY, True, True, None),
21+
(X_ODD_N, False, False, 3.2),
22+
(X_ODD_N, True, False, 3.2),
23+
(X_ODD_N, False, True, 3.2),
24+
(X_ODD_N, True, True, 3.2),
25+
(X_EVEN_N, False, False, 1.6),
26+
(X_EVEN_N, True, False, 0.0),
27+
(X_EVEN_N, False, True, 3.2),
28+
(X_EVEN_N, True, True, 0.0),
29+
],
30+
)
31+
def test_median(
32+
comb: Tuple[Union[list, np.ndarray], bool, bool, Optional[float]],
33+
as_array: bool,
34+
):
35+
x, low, high, expected = comb
36+
if as_array:
37+
x = np.array(x)
38+
39+
# for empty samples, an error should be raised
40+
if expected is None:
41+
with pytest.raises(ValueError):
42+
median(x, low=low, high=high)
43+
44+
return
45+
46+
# otherwise, the expected median should be returned
47+
assert median(x, low=low, high=high) == expected

0 commit comments

Comments
 (0)