Skip to content

Commit 9b7fb79

Browse files
committed
Add support for sparse labels in MatthewsCorrelationCoefficient
1 parent 97eb293 commit 9b7fb79

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

tensorflow_addons/metrics/matthews_correlation_coefficient.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
):
7070
"""Creates a Matthews Correlation Coefficient instance."""
7171
super().__init__(name=name, dtype=dtype)
72-
self.num_classes = num_classes
72+
self.num_classes = max(2, num_classes)
7373
self.conf_mtx = self.add_weight(
7474
"conf_mtx",
7575
shape=(self.num_classes, self.num_classes),
@@ -82,9 +82,19 @@ def update_state(self, y_true, y_pred, sample_weight=None):
8282
y_true = tf.cast(y_true, dtype=self.dtype)
8383
y_pred = tf.cast(y_pred, dtype=self.dtype)
8484

85+
if y_true.shape[-1] == 1:
86+
labels = tf.squeeze(tf.round(y_true), axis=-1)
87+
else:
88+
labels = tf.argmax(y_true, 1)
89+
90+
if y_pred.shape[-1] == 1:
91+
predictions = tf.squeeze(tf.round(y_pred), axis=-1)
92+
else:
93+
predictions = tf.argmax(y_pred, 1)
94+
8595
new_conf_mtx = tf.math.confusion_matrix(
86-
labels=tf.argmax(y_true, 1),
87-
predictions=tf.argmax(y_pred, 1),
96+
labels=labels,
97+
predictions=predictions,
8898
num_classes=self.num_classes,
8999
weights=sample_weight,
90100
dtype=self.dtype,
@@ -126,7 +136,4 @@ def reset_states(self):
126136
"""Resets all of the metric state variables."""
127137

128138
for v in self.variables:
129-
K.set_value(
130-
v,
131-
np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype),
132-
)
139+
K.set_value(v, np.zeros(v.shape, v.dtype.as_numpy_dtype))

tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,30 @@
2323

2424
def test_config():
2525
# mcc object
26-
mcc1 = MatthewsCorrelationCoefficient(num_classes=1)
27-
assert mcc1.num_classes == 1
26+
mcc1 = MatthewsCorrelationCoefficient(num_classes=2)
27+
assert mcc1.num_classes == 2
2828
assert mcc1.dtype == tf.float32
2929
# check configure
3030
mcc2 = MatthewsCorrelationCoefficient.from_config(mcc1.get_config())
31-
assert mcc2.num_classes == 1
31+
assert mcc2.num_classes == 2
3232
assert mcc2.dtype == tf.float32
3333

3434

3535
def check_results(obj, value):
3636
np.testing.assert_allclose(value, obj.result().numpy(), atol=1e-6)
3737

3838

39+
def test_binary_classes_sparse():
40+
gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32)
41+
preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32)
42+
# Initialize
43+
mcc = MatthewsCorrelationCoefficient(1)
44+
# Update
45+
mcc.update_state(gt_label, preds)
46+
# Check results
47+
check_results(mcc, [-0.33333334])
48+
49+
3950
def test_binary_classes():
4051
gt_label = tf.constant(
4152
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32
@@ -91,6 +102,16 @@ def test_multiple_classes():
91102
sklearn_result = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1))
92103
check_results(mcc, sklearn_result)
93104

105+
gt_label_sparse = tf.constant(
106+
[[0.0], [2.0], [0.0], [2.0], [1.0], [1.0], [0.0], [0.0], [2.0], [1.0]]
107+
)
108+
preds_sparse = tf.constant(
109+
[[2.0], [0.0], [2.0], [2.0], [2.0], [2.0], [2.0], [0.0], [2.0], [2.0]]
110+
)
111+
mcc = MatthewsCorrelationCoefficient(3)
112+
mcc.update_state(gt_label_sparse, preds_sparse)
113+
check_results(mcc, sklearn_result)
114+
94115

95116
# Keras model API check
96117
def test_keras_model():

0 commit comments

Comments
 (0)