|
23 | 23 |
|
24 | 24 | def test_config():
|
25 | 25 | # mcc object
|
26 |
| - mcc1 = MatthewsCorrelationCoefficient(num_classes=1) |
27 |
| - assert mcc1.num_classes == 1 |
| 26 | + mcc1 = MatthewsCorrelationCoefficient(num_classes=2) |
| 27 | + assert mcc1.num_classes == 2 |
28 | 28 | assert mcc1.dtype == tf.float32
|
29 | 29 | # check configure
|
30 | 30 | mcc2 = MatthewsCorrelationCoefficient.from_config(mcc1.get_config())
|
31 |
| - assert mcc2.num_classes == 1 |
| 31 | + assert mcc2.num_classes == 2 |
32 | 32 | assert mcc2.dtype == tf.float32
|
33 | 33 |
|
34 | 34 |
|
35 | 35 | def check_results(obj, value):
|
36 | 36 | np.testing.assert_allclose(value, obj.result().numpy(), atol=1e-6)
|
37 | 37 |
|
38 | 38 |
|
| 39 | +def test_binary_classes_sparse(): |
| 40 | + gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) |
| 41 | + preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) |
| 42 | + # Initialize |
| 43 | + mcc = MatthewsCorrelationCoefficient(1) |
| 44 | + # Update |
| 45 | + mcc.update_state(gt_label, preds) |
| 46 | + # Check results |
| 47 | + check_results(mcc, [-0.33333334]) |
| 48 | + |
| 49 | + |
39 | 50 | def test_binary_classes():
|
40 | 51 | gt_label = tf.constant(
|
41 | 52 | [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32
|
@@ -91,6 +102,16 @@ def test_multiple_classes():
|
91 | 102 | sklearn_result = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1))
|
92 | 103 | check_results(mcc, sklearn_result)
|
93 | 104 |
|
| 105 | + gt_label_sparse = tf.constant( |
| 106 | + [[0.0], [2.0], [0.0], [2.0], [1.0], [1.0], [0.0], [0.0], [2.0], [1.0]] |
| 107 | + ) |
| 108 | + preds_sparse = tf.constant( |
| 109 | + [[2.0], [0.0], [2.0], [2.0], [2.0], [2.0], [2.0], [0.0], [2.0], [2.0]] |
| 110 | + ) |
| 111 | + mcc = MatthewsCorrelationCoefficient(3) |
| 112 | + mcc.update_state(gt_label_sparse, preds_sparse) |
| 113 | + check_results(mcc, sklearn_result) |
| 114 | + |
94 | 115 |
|
95 | 116 | # Keras model API check
|
96 | 117 | def test_keras_model():
|
|
0 commit comments