-
Notifications
You must be signed in to change notification settings - Fork 615
Fix support for sparse labels in MatthewsCorrelationCoefficient #2495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,7 +69,7 @@ def __init__( | |
): | ||
"""Creates a Matthews Correlation Coefficient instance.""" | ||
super().__init__(name=name, dtype=dtype) | ||
self.num_classes = num_classes | ||
self.num_classes = max(2, num_classes) | ||
self.conf_mtx = self.add_weight( | ||
"conf_mtx", | ||
shape=(self.num_classes, self.num_classes), | ||
|
@@ -82,9 +82,19 @@ def update_state(self, y_true, y_pred, sample_weight=None): | |
y_true = tf.cast(y_true, dtype=self.dtype) | ||
y_pred = tf.cast(y_pred, dtype=self.dtype) | ||
Comment on lines
82
to
83
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of rounding in the future, I suggest cast it to tf.int64 |
||
|
||
if y_true.shape[-1] == 1: | ||
labels = tf.squeeze(tf.round(y_true), axis=-1) | ||
else: | ||
labels = tf.argmax(y_true, 1) | ||
|
||
if y_pred.shape[-1] == 1: | ||
predictions = tf.squeeze(tf.round(y_pred), axis=-1) | ||
else: | ||
predictions = tf.argmax(y_pred, 1) | ||
|
||
new_conf_mtx = tf.math.confusion_matrix( | ||
labels=tf.argmax(y_true, 1), | ||
predictions=tf.argmax(y_pred, 1), | ||
labels=labels, | ||
predictions=predictions, | ||
num_classes=self.num_classes, | ||
weights=sample_weight, | ||
dtype=self.dtype, | ||
|
@@ -126,7 +136,4 @@ def reset_states(self): | |
"""Resets all of the metric state variables.""" | ||
|
||
for v in self.variables: | ||
K.set_value( | ||
v, | ||
np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype), | ||
) | ||
K.set_value(v, np.zeros(v.shape, v.dtype.as_numpy_dtype)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,19 +23,30 @@ | |
|
||
def test_config(): | ||
# mcc object | ||
mcc1 = MatthewsCorrelationCoefficient(num_classes=1) | ||
assert mcc1.num_classes == 1 | ||
mcc1 = MatthewsCorrelationCoefficient(num_classes=2) | ||
assert mcc1.num_classes == 2 | ||
assert mcc1.dtype == tf.float32 | ||
# check configure | ||
mcc2 = MatthewsCorrelationCoefficient.from_config(mcc1.get_config()) | ||
assert mcc2.num_classes == 1 | ||
assert mcc2.num_classes == 2 | ||
assert mcc2.dtype == tf.float32 | ||
|
||
|
||
def check_results(obj, value): | ||
np.testing.assert_allclose(value, obj.result().numpy(), atol=1e-6) | ||
|
||
|
||
def test_binary_classes_sparse(): | ||
gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) | ||
preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) | ||
# Initialize | ||
mcc = MatthewsCorrelationCoefficient(1) | ||
# Update | ||
mcc.update_state(gt_label, preds) | ||
# Check results | ||
check_results(mcc, [-0.33333334]) | ||
Comment on lines
+39
to
+47
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This restores the behaviour of v0.12 which was removed here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure the above example is sparse, can you report the result of the following input? gt_label = tf.constant([1, 1, 1, 0], dtype=tf.float32)
preds = tf.constant([1, 0, 1, 1], dtype=tf.float32) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I don't think this would work correctly with the current code. But the motivation of this PR was to restore the behaviour of v0.12 and earlier. I'd be happy to change it to support the above example though. Is there a place where the conventions of the metrics are documented? /cc @seanpmorgan There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In tensorflow (and in most libraries), sparse labels are tensors of rank 1. That means it's dimensionality is As for conventions, you can check README.md in metrics to get a brief idea on how things work. You can also check metrics.py from tensorflow official for code conventions as we follow the same standard. There's also an exquisite set of metrics available in our Let me know if this helps, 👍 |
||
|
||
|
||
def test_binary_classes(): | ||
gt_label = tf.constant( | ||
[[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32 | ||
|
@@ -91,6 +102,16 @@ def test_multiple_classes(): | |
sklearn_result = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1)) | ||
check_results(mcc, sklearn_result) | ||
|
||
gt_label_sparse = tf.constant( | ||
[[0.0], [2.0], [0.0], [2.0], [1.0], [1.0], [0.0], [0.0], [2.0], [1.0]] | ||
) | ||
preds_sparse = tf.constant( | ||
[[2.0], [0.0], [2.0], [2.0], [2.0], [2.0], [2.0], [0.0], [2.0], [2.0]] | ||
) | ||
mcc = MatthewsCorrelationCoefficient(3) | ||
mcc.update_state(gt_label_sparse, preds_sparse) | ||
check_results(mcc, sklearn_result) | ||
|
||
|
||
# Keras model API check | ||
def test_keras_model(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of
max(2, num_classes)
, add a check to detect if num_classes is less than 2 and appropriately throw error. See cohen_kappa.py for this.