
Commit 959b45d

feat: lazify classification metrics and clean tests
1 parent b292a49 · commit 959b45d

File tree (2 files changed: +33 -40 lines)

    ibis_ml/metrics.py
    tests/test_metrics.py

ibis_ml/metrics.py

Lines changed: 17 additions & 8 deletions
@@ -18,8 +18,8 @@ def accuracy_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
 
     Examples
     --------
-    >>> from ibis_ml.metrics import accuracy_score
     >>> import ibis
+    >>> from ibis_ml.metrics import accuracy_score
     >>> ibis.options.interactive = True
     >>> t = ibis.memtable(
     ...     {
@@ -29,9 +29,11 @@ def accuracy_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
     ...     }
     ... )
     >>> accuracy_score(t.actual, t.prediction)
-    0.5833333333333334
+    ┌──────────┐
+    │ 0.583333 │
+    └──────────┘
     """
-    return (y_true == y_pred).mean().to_pyarrow().as_py()
+    return (y_true == y_pred).mean()  # .to_pyarrow().as_py()
 
 
 def precision_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
@@ -62,11 +64,13 @@ def precision_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
     ...     }
     ... )
     >>> precision_score(t.actual, t.prediction)
-    0.6666666666666666
+    ┌──────────┐
+    │ 0.666667 │
+    └──────────┘
     """
     true_positive = (y_true & y_pred).sum()
     predicted_positive = y_pred.sum()
-    return (true_positive / predicted_positive).to_pyarrow().as_py()
+    return true_positive / predicted_positive
 
 
 def recall_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
@@ -83,6 +87,7 @@ def recall_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
     -------
     float
         The recall score, representing the fraction of true positive predictions.
+
     Examples
     --------
     >>> from ibis_ml.metrics import recall_score
@@ -96,11 +101,13 @@ def recall_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
     ...     }
     ... )
     >>> recall_score(t.actual, t.prediction)
-    0.5714285714285714
+    ┌──────────┐
+    │ 0.571429 │
+    └──────────┘
     """
     true_positive = (y_true & y_pred).sum()
     actual_positive = y_true.sum()
-    return (true_positive / actual_positive).to_pyarrow().as_py()
+    return true_positive / actual_positive
 
 
 def f1_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
@@ -131,7 +138,9 @@ def f1_score(y_true: dt.Integer, y_pred: dt.Integer) -> float:
     ...     }
     ... )
     >>> f1_score(t.actual, t.prediction)
-    0.6153846153846154
+    ┌──────────┐
+    │ 0.615385 │
+    └──────────┘
     """
     precision = precision_score(y_true, y_pred)
     recall = recall_score(y_true, y_pred)
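
What this means for callers: after this commit the metric functions return deferred ibis expressions instead of eager Python floats. A minimal usage sketch (not part of the commit; the toy data below is made up for illustration), assuming a default in-memory backend such as DuckDB:

    import ibis

    from ibis_ml.metrics import accuracy_score

    t = ibis.memtable(
        {
            "actual": [1, 0, 0, 1, 1, 0],
            "prediction": [1, 0, 1, 1, 0, 0],
        }
    )

    # Building the metric no longer touches the backend; it only builds an expression.
    expr = accuracy_score(t.actual, t.prediction)

    # Materialize explicitly when a concrete number is needed, as the updated tests do.
    value = expr.to_pyarrow().as_py()
    print(value)  # ~0.6667 for this toy data (4 of 6 predictions match)

With ibis.options.interactive = True, printing expr directly renders the boxed output shown in the updated docstrings.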

tests/test_metrics.py

Lines changed: 16 additions & 32 deletions
@@ -1,11 +1,8 @@
 import ibis
 import pytest
-from sklearn.metrics import accuracy_score as sk_accuracy_score
-from sklearn.metrics import f1_score as sk_f1_score
-from sklearn.metrics import precision_score as sk_precision_score
-from sklearn.metrics import recall_score as sk_recall_score
+import sklearn.metrics
 
-from ibis_ml.metrics import accuracy_score, f1_score, precision_score, recall_score
+import ibis_ml.metrics
 
 
 @pytest.fixture
@@ -19,33 +16,20 @@ def results_table():
     )
 
 
-def test_accuracy_score(results_table):
+@pytest.mark.parametrize(
+    "metric_name",
+    [
+        pytest.param("accuracy_score", id="accuracy_score"),
+        pytest.param("precision_score", id="precision_score"),
+        pytest.param("recall_score", id="recall_score"),
+        pytest.param("f1_score", id="f1_score"),
+    ],
+)
+def test_classification_metrics(results_table, metric_name):
+    ibis_ml_func = getattr(ibis_ml.metrics, metric_name)
+    sklearn_func = getattr(sklearn.metrics, metric_name)
     t = results_table
     df = t.to_pandas()
-    result = accuracy_score(t.actual, t.prediction)
-    expected = sk_accuracy_score(df["actual"], df["prediction"])
-    assert result == pytest.approx(expected, abs=1e-4)
-
-
-def test_precision_score(results_table):
-    t = results_table
-    df = t.to_pandas()
-    result = precision_score(t.actual, t.prediction)
-    expected = sk_precision_score(df["actual"], df["prediction"])
-    assert result == pytest.approx(expected, abs=1e-4)
-
-
-def test_recall_score(results_table):
-    t = results_table
-    df = t.to_pandas()
-    result = recall_score(t.actual, t.prediction)
-    expected = sk_recall_score(df["actual"], df["prediction"])
-    assert result == pytest.approx(expected, abs=1e-4)
-
-
-def test_f1_score(results_table):
-    t = results_table
-    df = t.to_pandas()
-    result = f1_score(t.actual, t.prediction)
-    expected = sk_f1_score(df["actual"], df["prediction"])
+    result = ibis_ml_func(t.actual, t.prediction).to_pyarrow().as_py()
+    expected = sklearn_func(df["actual"], df["prediction"])
     assert result == pytest.approx(expected, abs=1e-4)
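
One payoff of keeping the metrics lazy is that several of them can be sent to the backend as a single query instead of one round trip per metric. A hedged sketch (not part of the commit; the aggregate() packaging and toy data are illustrative assumptions):

    import ibis

    import ibis_ml.metrics as metrics

    t = ibis.memtable(
        {
            "actual": [1, 0, 0, 1, 1, 0],
            "prediction": [1, 0, 1, 1, 0, 0],
        }
    )

    # Each metric call now returns a reduction expression over t, so the four
    # metrics can be packed into one aggregation and executed together.
    summary = t.aggregate(
        accuracy=metrics.accuracy_score(t.actual, t.prediction),
        precision=metrics.precision_score(t.actual, t.prediction),
        recall=metrics.recall_score(t.actual, t.prediction),
        f1=metrics.f1_score(t.actual, t.prediction),
    )
    print(summary.to_pandas())  # one row with all four metric values

The old eager API, which called .to_pyarrow().as_py() inside each metric, forced a separate backend execution for every metric.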
