23
23
24
24
def test_config ():
25
25
# mcc object
26
- mcc1 = MatthewsCorrelationCoefficient (num_classes = 1 )
27
- assert mcc1 .num_classes == 1
26
+ mcc1 = MatthewsCorrelationCoefficient (num_classes = 2 )
27
+ assert mcc1 .num_classes == 2
28
28
assert mcc1 .dtype == tf .float32
29
29
# check configure
30
30
mcc2 = MatthewsCorrelationCoefficient .from_config (mcc1 .get_config ())
31
- assert mcc2 .num_classes == 1
31
+ assert mcc2 .num_classes == 2
32
32
assert mcc2 .dtype == tf .float32
33
33
34
34
35
35
def check_results (obj , value ):
36
36
np .testing .assert_allclose (value , obj .result ().numpy (), atol = 1e-6 )
37
37
38
38
39
+ def test_binary_classes_sparse ():
40
+ gt_label = tf .constant ([[1.0 ], [1.0 ], [1.0 ], [0.0 ]], dtype = tf .float32 )
41
+ preds = tf .constant ([[1.0 ], [0.0 ], [1.0 ], [1.0 ]], dtype = tf .float32 )
42
+ # Initialize
43
+ mcc = MatthewsCorrelationCoefficient (1 )
44
+ # Update
45
+ mcc .update_state (gt_label , preds )
46
+ # Check results
47
+ check_results (mcc , [- 0.33333334 ])
48
+
49
+
39
50
def test_binary_classes ():
40
51
gt_label = tf .constant (
41
52
[[0.0 , 1.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = tf .float32
@@ -91,6 +102,16 @@ def test_multiple_classes():
91
102
sklearn_result = sklearn_matthew (gt_label .argmax (axis = 1 ), preds .argmax (axis = 1 ))
92
103
check_results (mcc , sklearn_result )
93
104
105
+ gt_label_sparse = tf .constant (
106
+ [[0.0 ], [2.0 ], [0.0 ], [2.0 ], [1.0 ], [1.0 ], [0.0 ], [0.0 ], [2.0 ], [1.0 ]]
107
+ )
108
+ preds_sparse = tf .constant (
109
+ [[2.0 ], [0.0 ], [2.0 ], [2.0 ], [2.0 ], [2.0 ], [2.0 ], [0.0 ], [2.0 ], [2.0 ]]
110
+ )
111
+ mcc = MatthewsCorrelationCoefficient (3 )
112
+ mcc .update_state (gt_label_sparse , preds_sparse )
113
+ check_results (mcc , sklearn_result )
114
+
94
115
95
116
# Keras model API check
96
117
def test_keras_model ():
@@ -110,13 +131,9 @@ def test_keras_model():
110
131
111
132
112
133
def test_reset_states_graph ():
113
- gt_label = tf .constant (
114
- [[0.0 , 1.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ], [1.0 , 0.0 ]], dtype = tf .float32
115
- )
116
- preds = tf .constant (
117
- [[0.0 , 1.0 ], [1.0 , 0.0 ], [0.0 , 1.0 ], [0.0 , 1.0 ]], dtype = tf .float32
118
- )
119
- mcc = MatthewsCorrelationCoefficient (2 )
134
+ gt_label = tf .constant ([[1.0 ], [1.0 ], [1.0 ], [0.0 ]], dtype = tf .float32 )
135
+ preds = tf .constant ([[1.0 ], [0.0 ], [1.0 ], [1.0 ]], dtype = tf .float32 )
136
+ mcc = MatthewsCorrelationCoefficient (1 )
120
137
mcc .update_state (gt_label , preds )
121
138
122
139
@tf .function
0 commit comments