Fix support for sparse labels in MatthewsCorrelationCoefficient #2495

Status: Closed · wants to merge 1 commit
21 changes: 14 additions & 7 deletions tensorflow_addons/metrics/matthews_correlation_coefficient.py
@@ -69,7 +69,7 @@ def __init__(
     ):
         """Creates a Matthews Correlation Coefficient instance."""
         super().__init__(name=name, dtype=dtype)
-        self.num_classes = num_classes
+        self.num_classes = max(2, num_classes)
         self.conf_mtx = self.add_weight(
             "conf_mtx",
             shape=(self.num_classes, self.num_classes),
Comment on lines 69 to 75 (Contributor):

Instead of max(2, num_classes), add a check that raises an error when num_classes is less than 2. See cohen_kappa.py for an example of this.
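A minimal sketch of that suggestion (hypothetical helper name; the actual check in cohen_kappa.py may differ in wording):

```python
def check_num_classes(num_classes):
    # Hypothetical validation, per the review comment above: raise
    # instead of silently clamping with max(2, num_classes), since
    # MCC needs at least a 2x2 confusion matrix.
    if num_classes < 2:
        raise ValueError(
            "Argument `num_classes` must be >= 2, got %d" % num_classes
        )
    return num_classes
```

Raising makes a bad `num_classes=1` visible at construction time instead of silently changing the confusion-matrix shape.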

@@ -82,9 +82,19 @@ def update_state(self, y_true, y_pred, sample_weight=None):
         y_true = tf.cast(y_true, dtype=self.dtype)
         y_pred = tf.cast(y_pred, dtype=self.dtype)
Comment on lines 82 to 83 (Contributor):

Instead of rounding, in the future I suggest casting to tf.int64.
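A sketch of that suggestion (NumPy stand-in for the TensorFlow ops; in the metric itself this would be `tf.cast(y_true, tf.int64)`):

```python
import numpy as np

def to_sparse_labels(y):
    # Cast sparse labels to int64 instead of rounding. Note that a
    # plain cast truncates (0.9 -> 0), so this assumes labels are
    # exact integers stored as floats, which is true for class ids.
    y = np.asarray(y)
    if y.shape[-1] == 1:
        return np.squeeze(y, axis=-1).astype(np.int64)
    # One-hot / probability inputs fall through to argmax.
    return np.argmax(y, axis=1)
```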


+        if y_true.shape[-1] == 1:
+            labels = tf.squeeze(tf.round(y_true), axis=-1)
+        else:
+            labels = tf.argmax(y_true, 1)
+
+        if y_pred.shape[-1] == 1:
+            predictions = tf.squeeze(tf.round(y_pred), axis=-1)
+        else:
+            predictions = tf.argmax(y_pred, 1)
+
         new_conf_mtx = tf.math.confusion_matrix(
-            labels=tf.argmax(y_true, 1),
-            predictions=tf.argmax(y_pred, 1),
+            labels=labels,
+            predictions=predictions,
             num_classes=self.num_classes,
             weights=sample_weight,
             dtype=self.dtype,
@@ -126,7 +136,4 @@ def reset_states(self):
         """Resets all of the metric state variables."""

         for v in self.variables:
-            K.set_value(
-                v,
-                np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype),
-            )
+            K.set_value(v, np.zeros(v.shape, v.dtype.as_numpy_dtype))
@@ -23,19 +23,30 @@

 def test_config():
     # mcc object
-    mcc1 = MatthewsCorrelationCoefficient(num_classes=1)
-    assert mcc1.num_classes == 1
+    mcc1 = MatthewsCorrelationCoefficient(num_classes=2)
+    assert mcc1.num_classes == 2
     assert mcc1.dtype == tf.float32
     # check configure
     mcc2 = MatthewsCorrelationCoefficient.from_config(mcc1.get_config())
-    assert mcc2.num_classes == 1
+    assert mcc2.num_classes == 2
     assert mcc2.dtype == tf.float32


 def check_results(obj, value):
     np.testing.assert_allclose(value, obj.result().numpy(), atol=1e-6)
+
+
+def test_binary_classes_sparse():
+    gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32)
+    preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32)
+    # Initialize
+    mcc = MatthewsCorrelationCoefficient(1)
+    # Update
+    mcc.update_state(gt_label, preds)
+    # Check results
+    check_results(mcc, [-0.33333334])
Comment on lines +39 to +47 — @lgeiger (Contributor Author), Jun 2, 2021:

This restores the behaviour of v0.12, which was removed here.

@jonpsy (Contributor), Jun 3, 2021:

Not sure the above example is sparse; can you report the result of the following input?

 gt_label = tf.constant([1, 1, 1, 0], dtype=tf.float32)
 preds = tf.constant([1, 0, 1, 1], dtype=tf.float32)
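For reference, the expected MCC for this input can be computed by hand from the 2x2 confusion matrix (plain Python, independent of the metric class):

```python
import math

gt = [1, 1, 1, 0]
pred = [1, 0, 1, 1]

tp = sum(g == 1 and p == 1 for g, p in zip(gt, pred))  # 2
tn = sum(g == 0 and p == 0 for g, p in zip(gt, pred))  # 0
fp = sum(g == 0 and p == 1 for g, p in zip(gt, pred))  # 1
fn = sum(g == 1 and p == 0 for g, p in zip(gt, pred))  # 1

# MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
mcc = (tp * tn - fp * fn) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)
print(mcc)  # -1/3, matching the -0.33333334 asserted above
```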

@lgeiger (Contributor Author), Jun 3, 2021:

Yeah, I don't think this would work correctly with the current code. But the motivation of this PR was to restore the behaviour of v0.12 and earlier. I'd be happy to change it to support the above example, though. Is there a place where the conventions of the metrics are documented? /cc @seanpmorgan

Comment (Contributor):

In TensorFlow (and in most libraries), sparse labels are tensors of rank 1: their shape is (num_samples,), not (num_samples, 1).

As for conventions, you can check the README.md in metrics to get a brief idea of how things work. You can also check metrics.py from the official TensorFlow repository for code conventions, as we follow the same standard. There is also an extensive set of metrics in our metrics folder which will give you a good idea.

Let me know if this helps 👍
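Following that convention, a hypothetical label-normalization helper (NumPy stand-in for the TensorFlow ops) that would accept rank-1 sparse labels as well as the (n, 1) and one-hot shapes the PR currently handles:

```python
import numpy as np

def normalize_labels(y):
    y = np.asarray(y)
    # Rank-1 sparse labels, shape (num_samples,):
    if y.ndim == 1:
        return y.astype(np.int64)
    # Rank-2 sparse labels, shape (num_samples, 1):
    if y.shape[-1] == 1:
        return np.squeeze(y, axis=-1).astype(np.int64)
    # One-hot / probability labels, shape (num_samples, num_classes):
    return np.argmax(y, axis=1)
```

This is only a sketch of the shape handling under discussion, not the code merged in this PR.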



 def test_binary_classes():
     gt_label = tf.constant(
         [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32
@@ -91,6 +102,16 @@ def test_multiple_classes():
     sklearn_result = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1))
     check_results(mcc, sklearn_result)

+    gt_label_sparse = tf.constant(
+        [[0.0], [2.0], [0.0], [2.0], [1.0], [1.0], [0.0], [0.0], [2.0], [1.0]]
+    )
+    preds_sparse = tf.constant(
+        [[2.0], [0.0], [2.0], [2.0], [2.0], [2.0], [2.0], [0.0], [2.0], [2.0]]
+    )
+    mcc = MatthewsCorrelationCoefficient(3)
+    mcc.update_state(gt_label_sparse, preds_sparse)
+    check_results(mcc, sklearn_result)


# Keras model API check
def test_keras_model():