Skip to content

Commit 09cc881

Browse files
committed
fix failing metrics tests under python3.5
1 parent 9b04774 commit 09cc881

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
build/
33
dist/
44
mla.egg-info/
5+
.cache

mla/metrics/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55

66
def check_data(a, b):
7-
if isinstance(a, list):
7+
if isinstance(a, list) or isinstance(a, range):
88
a = np.array(a)
99

10-
if isinstance(b, list):
10+
if isinstance(b, list) or isinstance(b, range):
1111
b = np.array(b)
1212

1313
if type(a) != type(b):

mla/metrics/metrics.py

-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def root_mean_squared_log_error(actual, predicted):
6060

6161
def logloss(actual, predicted):
6262
predicted = np.clip(predicted, EPS, 1 - EPS)
63-
predicted /= predicted.sum(axis=1, keepdims=True)
6463
loss = -np.sum(actual * np.log(predicted))
6564
return loss / float(actual.shape[0])
6665

0 commit comments

Comments
 (0)