5
5
from numpy .testing import assert_almost_equal
6
6
7
7
from mla .metrics .base import check_data , validate_input
8
- from mla .metrics .metrics import *
8
+ from mla .metrics .metrics import get_metric
9
9
10
10
11
11
def test_data_validation ():
@@ -26,53 +26,63 @@ def metric(name):
26
26
27
27
28
28
def test_classification_error ():
29
- assert metric ('classification_error' )([1 , 2 , 3 , 4 ], [1 , 2 , 3 , 4 ]) == 0
30
- assert metric ('classification_error' )([1 , 2 , 3 , 4 ], [1 , 2 , 3 , 5 ]) == 0.25
31
- assert metric ('classification_error' )([1 , 1 , 1 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 0 , 0 ]) == (1.0 / 6 )
29
+ f = metric ('classification_error' )
30
+ assert f ([1 , 2 , 3 , 4 ], [1 , 2 , 3 , 4 ]) == 0
31
+ assert f ([1 , 2 , 3 , 4 ], [1 , 2 , 3 , 5 ]) == 0.25
32
+ assert f ([1 , 1 , 1 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 0 , 0 ]) == (1.0 / 6 )
32
33
33
34
34
35
def test_absolute_error ():
35
- assert metric ('absolute_error' )([3 ], [5 ]) == [2 ]
36
- assert metric ('absolute_error' )([- 1 ], [- 4 ]) == [3 ]
36
+ f = metric ('absolute_error' )
37
+ assert f ([3 ], [5 ]) == [2 ]
38
+ assert f ([- 1 ], [- 4 ]) == [3 ]
37
39
38
40
39
41
def test_mean_absolute_error ():
40
- assert metric ('mean_absolute_error' )([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
41
- assert metric ('mean_absolute_error' )([1 , 2 , 3 ], [3 , 2 , 1 ]) == 4 / 3
42
+ f = metric ('mean_absolute_error' )
43
+ assert f ([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
44
+ assert f ([1 , 2 , 3 ], [3 , 2 , 1 ]) == 4 / 3
42
45
43
46
44
47
def test_squared_error ():
45
- assert metric ('squared_error' )([1 ], [1 ]) == [0 ]
46
- assert metric ('squared_error' )([3 ], [1 ]) == [4 ]
48
+ f = metric ('squared_error' )
49
+ assert f ([1 ], [1 ]) == [0 ]
50
+ assert f ([3 ], [1 ]) == [4 ]
47
51
48
52
49
53
def test_squared_log_error ():
50
- assert metric ('squared_log_error' )([1 ], [1 ]) == [0 ]
51
- assert metric ('squared_log_error' )([3 ], [1 ]) == [np .log (2 ) ** 2 ]
52
- assert metric ('squared_log_error' )([np .exp (2 ) - 1 ], [np .exp (1 ) - 1 ]) == [1.0 ]
54
+ f = metric ('squared_log_error' )
55
+ assert f ([1 ], [1 ]) == [0 ]
56
+ assert f ([3 ], [1 ]) == [np .log (2 ) ** 2 ]
57
+ assert f ([np .exp (2 ) - 1 ], [np .exp (1 ) - 1 ]) == [1.0 ]
53
58
54
59
55
- def test_mean_squered_error ():
56
- assert metric ('mean_squared_log_error' )([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
57
- assert metric ('mean_squared_log_error' )([1 , 2 , 3 , np .exp (1 ) - 1 ], [1 , 2 , 3 , np .exp (2 ) - 1 ]) == 0.25
60
+ def test_mean_squared_log_error ():
61
+ f = metric ('mean_squared_log_error' )
62
+ assert f ([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
63
+ assert f ([1 , 2 , 3 , np .exp (1 ) - 1 ], [1 , 2 , 3 , np .exp (2 ) - 1 ]) == 0.25
58
64
59
65
60
66
def test_root_mean_squared_log_error ():
61
- assert metric ('root_mean_squared_log_error' )([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
62
- assert metric ('root_mean_squared_log_error' )([1 , 2 , 3 , np .exp (1 ) - 1 ], [1 , 2 , 3 , np .exp (2 ) - 1 ]) == 0.5
67
+ f = metric ('root_mean_squared_log_error' )
68
+ assert f ([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
69
+ assert f ([1 , 2 , 3 , np .exp (1 ) - 1 ], [1 , 2 , 3 , np .exp (2 ) - 1 ]) == 0.5
63
70
64
71
65
72
def test_mean_squared_error ():
66
- assert metric ('mean_squared_error' )([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
67
- assert metric ('mean_squared_error' )(range (1 , 5 ), [1 , 2 , 3 , 6 ]) == 1
73
+ f = metric ('mean_squared_error' )
74
+ assert f ([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
75
+ assert f (range (1 , 5 ), [1 , 2 , 3 , 6 ]) == 1
68
76
69
77
70
78
def test_root_mean_squared_error ():
71
- assert metric ('root_mean_squared_error' )([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
72
- assert metric ('root_mean_squared_error' )(range (1 , 5 ), [1 , 2 , 3 , 5 ]) == 0.5
79
+ f = metric ('root_mean_squared_error' )
80
+ assert f ([1 , 2 , 3 ], [1 , 2 , 3 ]) == 0
81
+ assert f (range (1 , 5 ), [1 , 2 , 3 , 5 ]) == 0.5
73
82
74
83
75
84
def test_multiclass_logloss ():
76
- assert_almost_equal (metric ('logloss' )([1 ], [1 ]), 0 )
77
- assert_almost_equal (metric ('logloss' )([1 , 1 ], [1 , 1 ]), 0 )
78
- assert_almost_equal (metric ('logloss' )([1 ], [0.5 ]), - np .log (0.5 ))
85
+ f = metric ('logloss' )
86
+ assert_almost_equal (f ([1 ], [1 ]), 0 )
87
+ assert_almost_equal (f ([1 , 1 ], [1 , 1 ]), 0 )
88
+ assert_almost_equal (f ([1 ], [0.5 ]), - np .log (0.5 ))
0 commit comments