Skip to content

Commit e83e71c

Browse files
authored
Matthew Fix (#2406)
* initial changes * - doc test changed - multiclass works for OHE - fix multiclass test in unit test * - reset states * Update matthews_correlation_coefficient.py fix syntax * - doc test fixed * compare with sklearn * Update matthews_correlation_coefficient.py conf_matrix use self.dtype
1 parent adfe3ff commit e83e71c

File tree

2 files changed

+89
-77
lines changed

2 files changed

+89
-77
lines changed

tensorflow_addons/metrics/matthews_correlation_coefficient.py

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ class MatthewsCorrelationCoefficient(tf.keras.metrics.Metric):
5050
5151
Usage:
5252
53-
>>> y_true = np.array([[1.0], [1.0], [1.0], [0.0]], dtype=np.float32)
54-
>>> y_pred = np.array([[1.0], [0.0], [1.0], [1.0]], dtype=np.float32)
55-
>>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=1)
53+
>>> y_true = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=np.float32)
54+
>>> y_pred = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=np.float32)
55+
>>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=2)
5656
>>> metric.update_state(y_true, y_pred)
5757
>>> result = metric.result()
5858
>>> result.numpy()
59-
array([-0.33333334], dtype=float32)
59+
-0.33333334
6060
"""
6161

6262
@typechecked
@@ -70,28 +70,10 @@ def __init__(
7070
"""Creates a Matthews Correlation Coefficient instance."""
7171
super().__init__(name=name, dtype=dtype)
7272
self.num_classes = num_classes
73-
self.true_positives = self.add_weight(
74-
"true_positives",
75-
shape=[self.num_classes],
76-
initializer="zeros",
77-
dtype=self.dtype,
78-
)
79-
self.false_positives = self.add_weight(
80-
"false_positives",
81-
shape=[self.num_classes],
82-
initializer="zeros",
83-
dtype=self.dtype,
84-
)
85-
self.false_negatives = self.add_weight(
86-
"false_negatives",
87-
shape=[self.num_classes],
88-
initializer="zeros",
89-
dtype=self.dtype,
90-
)
91-
self.true_negatives = self.add_weight(
92-
"true_negatives",
93-
shape=[self.num_classes],
94-
initializer="zeros",
73+
self.conf_mtx = self.add_weight(
74+
"conf_mtx",
75+
shape=(self.num_classes, self.num_classes),
76+
initializer=tf.keras.initializers.zeros,
9577
dtype=self.dtype,
9678
)
9779

@@ -100,43 +82,35 @@ def update_state(self, y_true, y_pred, sample_weight=None):
10082
y_true = tf.cast(y_true, dtype=self.dtype)
10183
y_pred = tf.cast(y_pred, dtype=self.dtype)
10284

103-
true_positive = tf.math.count_nonzero(y_true * y_pred, 0)
104-
# true_negative
105-
y_true_negative = tf.math.not_equal(y_true, 1.0)
106-
y_pred_negative = tf.math.not_equal(y_pred, 1.0)
107-
true_negative = tf.math.count_nonzero(
108-
tf.math.logical_and(y_true_negative, y_pred_negative), axis=0
85+
new_conf_mtx = tf.math.confusion_matrix(
86+
labels=tf.argmax(y_true, 1),
87+
predictions=tf.argmax(y_pred, 1),
88+
num_classes=self.num_classes,
89+
weights=sample_weight,
90+
dtype=self.dtype,
10991
)
110-
# predicted sum
111-
pred_sum = tf.math.count_nonzero(y_pred, 0)
112-
# Ground truth label sum
113-
true_sum = tf.math.count_nonzero(y_true, 0)
114-
false_positive = pred_sum - true_positive
115-
false_negative = true_sum - true_positive
116-
117-
# true positive state_update
118-
self.true_positives.assign_add(tf.cast(true_positive, self.dtype))
119-
# false positive state_update
120-
self.false_positives.assign_add(tf.cast(false_positive, self.dtype))
121-
# false negative state_update
122-
self.false_negatives.assign_add(tf.cast(false_negative, self.dtype))
123-
# true negative state_update
124-
self.true_negatives.assign_add(tf.cast(true_negative, self.dtype))
92+
93+
self.conf_mtx.assign_add(new_conf_mtx)
12594

12695
def result(self):
127-
# numerator
128-
numerator1 = self.true_positives * self.true_negatives
129-
numerator2 = self.false_positives * self.false_negatives
130-
numerator = numerator1 - numerator2
131-
# denominator
132-
denominator1 = self.true_positives + self.false_positives
133-
denominator2 = self.true_positives + self.false_negatives
134-
denominator3 = self.true_negatives + self.false_positives
135-
denominator4 = self.true_negatives + self.false_negatives
136-
denominator = tf.math.sqrt(
137-
denominator1 * denominator2 * denominator3 * denominator4
138-
)
139-
mcc = tf.math.divide_no_nan(numerator, denominator)
96+
97+
true_sum = tf.reduce_sum(self.conf_mtx, axis=1)
98+
pred_sum = tf.reduce_sum(self.conf_mtx, axis=0)
99+
num_correct = tf.linalg.trace(self.conf_mtx)
100+
num_samples = tf.reduce_sum(pred_sum)
101+
102+
# covariance true-pred
103+
cov_ytyp = num_correct * num_samples - tf.tensordot(true_sum, pred_sum, axes=1)
104+
# covariance pred-pred
105+
cov_ypyp = num_samples ** 2 - tf.tensordot(pred_sum, pred_sum, axes=1)
106+
# covariance true-true
107+
cov_ytyt = num_samples ** 2 - tf.tensordot(true_sum, true_sum, axes=1)
108+
109+
mcc = cov_ytyp / tf.math.sqrt(cov_ytyt * cov_ypyp)
110+
111+
if tf.math.is_nan(mcc):
112+
mcc = tf.constant(0, dtype=self.dtype)
113+
140114
return mcc
141115

142116
def get_config(self):
@@ -150,5 +124,9 @@ def get_config(self):
150124

151125
def reset_states(self):
152126
"""Resets all of the metric state variables."""
153-
reset_value = np.zeros(self.num_classes, dtype=self.dtype)
154-
K.batch_set_value([(v, reset_value) for v in self.variables])
127+
128+
for v in self.variables:
129+
K.set_value(
130+
v,
131+
np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype),
132+
)

tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy as np
2020
from tensorflow_addons.metrics import MatthewsCorrelationCoefficient
21+
from sklearn.metrics import matthews_corrcoef as sklearn_matthew
2122

2223

2324
def test_config():
@@ -36,30 +37,59 @@ def check_results(obj, value):
3637

3738

3839
def test_binary_classes():
39-
gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32)
40-
preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32)
40+
gt_label = tf.constant(
41+
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32
42+
)
43+
preds = tf.constant(
44+
[[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=tf.float32
45+
)
4146
# Initialize
42-
mcc = MatthewsCorrelationCoefficient(1)
47+
mcc = MatthewsCorrelationCoefficient(2)
4348
# Update
4449
mcc.update_state(gt_label, preds)
4550
# Check results
4651
check_results(mcc, [-0.33333334])
4752

4853

54+
# See issue #2339
4955
def test_multiple_classes():
50-
gt_label = tf.constant(
51-
[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
52-
dtype=tf.float32,
56+
gt_label = np.array(
57+
[
58+
[1.0, 0.0, 0.0],
59+
[0.0, 0.0, 1.0],
60+
[1.0, 0.0, 0.0],
61+
[0.0, 0.0, 1.0],
62+
[0.0, 1.0, 0.0],
63+
[0.0, 1.0, 0.0],
64+
[1.0, 0.0, 0.0],
65+
[1.0, 0.0, 0.0],
66+
[0.0, 0.0, 1.0],
67+
[0.0, 1.0, 0.0],
68+
]
5369
)
54-
preds = tf.constant(
55-
[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]],
56-
dtype=tf.float32,
70+
preds = np.array(
71+
[
72+
[0.0, 0.0, 1.0],
73+
[1.0, 0.0, 0.0],
74+
[0.0, 0.0, 1.0],
75+
[0.0, 0.0, 1.0],
76+
[0.0, 0.0, 1.0],
77+
[0.0, 0.0, 1.0],
78+
[0.0, 0.0, 1.0],
79+
[1.0, 0.0, 0.0],
80+
[0.0, 0.0, 1.0],
81+
[0.0, 0.0, 1.0],
82+
]
5783
)
84+
tensor_gt_label = tf.constant(gt_label, dtype=tf.float32)
85+
tensor_preds = tf.constant(preds, dtype=tf.float32)
5886
# Initialize
5987
mcc = MatthewsCorrelationCoefficient(3)
60-
mcc.update_state(gt_label, preds)
61-
# Check results
62-
check_results(mcc, [-0.33333334, 1.0, 0.57735026])
88+
# Update
89+
mcc.update_state(tensor_gt_label, tensor_preds)
90+
# Check results by comparing to results of scikit-learn matthew implementation.
91+
sklearn_result = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1))
92+
check_results(mcc, sklearn_result)
6393

6494

6595
# Keras model API check
@@ -80,9 +110,13 @@ def test_keras_model():
80110

81111

82112
def test_reset_states_graph():
83-
gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32)
84-
preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32)
85-
mcc = MatthewsCorrelationCoefficient(1)
113+
gt_label = tf.constant(
114+
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32
115+
)
116+
preds = tf.constant(
117+
[[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=tf.float32
118+
)
119+
mcc = MatthewsCorrelationCoefficient(2)
86120
mcc.update_state(gt_label, preds)
87121

88122
@tf.function

0 commit comments

Comments
 (0)