Skip to content

Commit 13913fd

Browse files
authored Nov 15, 2016
Merge pull request rushter#5 from lucaskolstad/metrics-cleanup
argument order fix for binary_crossentropy. metrics test refactoring. rename euclidian -> euclidean.
2 parents f220655 + 0c50c21 commit 13913fd

File tree

4 files changed

+43
-32
lines changed

4 files changed

+43
-32
lines changed
 

‎mla/kmeans.py

+4-4
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from mla.base import BaseEstimator
8-
from mla.metrics.distance import euclidian_distance
8+
from mla.metrics.distance import euclidean_distance
99

1010
random.seed(1111)
1111

@@ -98,7 +98,7 @@ def _closest(self, fpoint, centroids):
9898
closest_index = None
9999
closest_distance = None
100100
for i, point in enumerate(centroids):
101-
dist = euclidian_distance(self.X[fpoint], point)
101+
dist = euclidean_distance(self.X[fpoint], point)
102102
if closest_index is None or dist < closest_distance:
103103
closest_index = i
104104
closest_distance = dist
@@ -109,7 +109,7 @@ def _get_centroid(self, cluster):
109109
return [np.mean(np.take(self.X[:, i], cluster)) for i in range(self.n_features)]
110110

111111
def _dist_from_centers(self):
112-
return np.array([min([euclidian_distance(x, c) for c in self.centroids]) for x in self.X])
112+
return np.array([min([euclidean_distance(x, c) for c in self.centroids]) for x in self.X])
113113

114114
def _choose_next_center(self):
115115
distances = self._dist_from_centers()
@@ -120,7 +120,7 @@ def _choose_next_center(self):
120120
return self.X[ind]
121121

122122
def _is_converged(self, centroids_old, centroids):
123-
return True if sum([euclidian_distance(centroids_old[i], centroids[i]) for i in range(self.K)]) == 0 else False
123+
return True if sum([euclidean_distance(centroids_old[i], centroids[i]) for i in range(self.K)]) == 0 else False
124124

125125
def plot(self, data=None):
126126
sns.set(style="white")

‎mla/metrics/distance.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import math
33

44

5-
def euclidian_distance(a, b):
5+
def euclidean_distance(a, b):
66
if isinstance(a, list) and isinstance(b, list):
77
a = np.array(a)
88
b = np.array(b)

‎mla/metrics/metrics.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -67,9 +67,10 @@ def hinge(actual, predicted):
6767
return np.mean(np.max(1. - actual * predicted, 0.))
6868

6969

70-
def binary_crossentropy(predicted, actual):
70+
def binary_crossentropy(actual, predicted):
7171
predicted = np.clip(predicted, EPS, 1 - EPS)
72-
return np.mean(-np.sum(actual * np.log(predicted) + (1 - actual) * np.log(1 - predicted)))
72+
return np.mean(-np.sum(actual * np.log(predicted) +
73+
(1 - actual) * np.log(1 - predicted)))
7374

7475

7576
# aliases

‎mla/metrics/tests/test_metrics.py

+35-25
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from numpy.testing import assert_almost_equal
66

77
from mla.metrics.base import check_data, validate_input
8-
from mla.metrics.metrics import *
8+
from mla.metrics.metrics import get_metric
99

1010

1111
def test_data_validation():
@@ -26,53 +26,63 @@ def metric(name):
2626

2727

2828
def test_classification_error():
29-
assert metric('classification_error')([1, 2, 3, 4], [1, 2, 3, 4]) == 0
30-
assert metric('classification_error')([1, 2, 3, 4], [1, 2, 3, 5]) == 0.25
31-
assert metric('classification_error')([1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]) == (1.0 / 6)
29+
f = metric('classification_error')
30+
assert f([1, 2, 3, 4], [1, 2, 3, 4]) == 0
31+
assert f([1, 2, 3, 4], [1, 2, 3, 5]) == 0.25
32+
assert f([1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]) == (1.0 / 6)
3233

3334

3435
def test_absolute_error():
35-
assert metric('absolute_error')([3], [5]) == [2]
36-
assert metric('absolute_error')([-1], [-4]) == [3]
36+
f = metric('absolute_error')
37+
assert f([3], [5]) == [2]
38+
assert f([-1], [-4]) == [3]
3739

3840

3941
def test_mean_absolute_error():
40-
assert metric('mean_absolute_error')([1, 2, 3], [1, 2, 3]) == 0
41-
assert metric('mean_absolute_error')([1, 2, 3], [3, 2, 1]) == 4 / 3
42+
f = metric('mean_absolute_error')
43+
assert f([1, 2, 3], [1, 2, 3]) == 0
44+
assert f([1, 2, 3], [3, 2, 1]) == 4 / 3
4245

4346

4447
def test_squared_error():
45-
assert metric('squared_error')([1], [1]) == [0]
46-
assert metric('squared_error')([3], [1]) == [4]
48+
f = metric('squared_error')
49+
assert f([1], [1]) == [0]
50+
assert f([3], [1]) == [4]
4751

4852

4953
def test_squared_log_error():
50-
assert metric('squared_log_error')([1], [1]) == [0]
51-
assert metric('squared_log_error')([3], [1]) == [np.log(2) ** 2]
52-
assert metric('squared_log_error')([np.exp(2) - 1], [np.exp(1) - 1]) == [1.0]
54+
f = metric('squared_log_error')
55+
assert f([1], [1]) == [0]
56+
assert f([3], [1]) == [np.log(2) ** 2]
57+
assert f([np.exp(2) - 1], [np.exp(1) - 1]) == [1.0]
5358

5459

55-
def test_mean_squered_error():
56-
assert metric('mean_squared_log_error')([1, 2, 3], [1, 2, 3]) == 0
57-
assert metric('mean_squared_log_error')([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.25
60+
def test_mean_squared_log_error():
61+
f = metric('mean_squared_log_error')
62+
assert f([1, 2, 3], [1, 2, 3]) == 0
63+
assert f([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.25
5864

5965

6066
def test_root_mean_squared_log_error():
61-
assert metric('root_mean_squared_log_error')([1, 2, 3], [1, 2, 3]) == 0
62-
assert metric('root_mean_squared_log_error')([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.5
67+
f = metric('root_mean_squared_log_error')
68+
assert f([1, 2, 3], [1, 2, 3]) == 0
69+
assert f([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.5
6370

6471

6572
def test_mean_squared_error():
66-
assert metric('mean_squared_error')([1, 2, 3], [1, 2, 3]) == 0
67-
assert metric('mean_squared_error')(range(1, 5), [1, 2, 3, 6]) == 1
73+
f = metric('mean_squared_error')
74+
assert f([1, 2, 3], [1, 2, 3]) == 0
75+
assert f(range(1, 5), [1, 2, 3, 6]) == 1
6876

6977

7078
def test_root_mean_squared_error():
71-
assert metric('root_mean_squared_error')([1, 2, 3], [1, 2, 3]) == 0
72-
assert metric('root_mean_squared_error')(range(1, 5), [1, 2, 3, 5]) == 0.5
79+
f = metric('root_mean_squared_error')
80+
assert f([1, 2, 3], [1, 2, 3]) == 0
81+
assert f(range(1, 5), [1, 2, 3, 5]) == 0.5
7382

7483

7584
def test_multiclass_logloss():
76-
assert_almost_equal(metric('logloss')([1], [1]), 0)
77-
assert_almost_equal(metric('logloss')([1, 1], [1, 1]), 0)
78-
assert_almost_equal(metric('logloss')([1], [0.5]), -np.log(0.5))
85+
f = metric('logloss')
86+
assert_almost_equal(f([1], [1]), 0)
87+
assert_almost_equal(f([1, 1], [1, 1]), 0)
88+
assert_almost_equal(f([1], [0.5]), -np.log(0.5))

0 commit comments

Comments
 (0)
Please sign in to comment.