@@ -50,13 +50,13 @@ class MatthewsCorrelationCoefficient(tf.keras.metrics.Metric):
50
50
51
51
Usage:
52
52
53
- >>> y_true = np.array([[1.0], [1.0], [1.0], [0.0]], dtype=np.float32)
54
- >>> y_pred = np.array([[1.0], [0 .0], [1.0], [1.0]], dtype=np.float32)
55
- >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=1 )
53
+ >>> y_true = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=np.float32)
54
+ >>> y_pred = np.array([[0.0, 1.0], [1.0, 0 .0], [0.0, 1.0], [0.0, 1.0]], dtype=np.float32)
55
+ >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=2 )
56
56
>>> metric.update_state(y_true, y_pred)
57
57
>>> result = metric.result()
58
58
>>> result.numpy()
59
- array([ -0.33333334], dtype=float32)
59
+ -0.33333334
60
60
"""
61
61
62
62
@typechecked
@@ -70,28 +70,10 @@ def __init__(
70
70
"""Creates a Matthews Correlation Coefficient instance."""
71
71
super ().__init__ (name = name , dtype = dtype )
72
72
self .num_classes = num_classes
73
- self .true_positives = self .add_weight (
74
- "true_positives" ,
75
- shape = [self .num_classes ],
76
- initializer = "zeros" ,
77
- dtype = self .dtype ,
78
- )
79
- self .false_positives = self .add_weight (
80
- "false_positives" ,
81
- shape = [self .num_classes ],
82
- initializer = "zeros" ,
83
- dtype = self .dtype ,
84
- )
85
- self .false_negatives = self .add_weight (
86
- "false_negatives" ,
87
- shape = [self .num_classes ],
88
- initializer = "zeros" ,
89
- dtype = self .dtype ,
90
- )
91
- self .true_negatives = self .add_weight (
92
- "true_negatives" ,
93
- shape = [self .num_classes ],
94
- initializer = "zeros" ,
73
+ self .conf_mtx = self .add_weight (
74
+ "conf_mtx" ,
75
+ shape = (self .num_classes , self .num_classes ),
76
+ initializer = tf .keras .initializers .zeros ,
95
77
dtype = self .dtype ,
96
78
)
97
79
@@ -100,43 +82,35 @@ def update_state(self, y_true, y_pred, sample_weight=None):
100
82
y_true = tf .cast (y_true , dtype = self .dtype )
101
83
y_pred = tf .cast (y_pred , dtype = self .dtype )
102
84
103
- true_positive = tf .math .count_nonzero ( y_true * y_pred , 0 )
104
- # true_negative
105
- y_true_negative = tf .math . not_equal ( y_true , 1.0 )
106
- y_pred_negative = tf . math . not_equal ( y_pred , 1.0 )
107
- true_negative = tf . math . count_nonzero (
108
- tf . math . logical_and ( y_true_negative , y_pred_negative ), axis = 0
85
+ new_conf_mtx = tf .math .confusion_matrix (
86
+ labels = tf . argmax ( y_true , 1 ),
87
+ predictions = tf .argmax ( y_pred , 1 ),
88
+ num_classes = self . num_classes ,
89
+ weights = sample_weight ,
90
+ dtype = self . dtype ,
109
91
)
110
- # predicted sum
111
- pred_sum = tf .math .count_nonzero (y_pred , 0 )
112
- # Ground truth label sum
113
- true_sum = tf .math .count_nonzero (y_true , 0 )
114
- false_positive = pred_sum - true_positive
115
- false_negative = true_sum - true_positive
116
-
117
- # true positive state_update
118
- self .true_positives .assign_add (tf .cast (true_positive , self .dtype ))
119
- # false positive state_update
120
- self .false_positives .assign_add (tf .cast (false_positive , self .dtype ))
121
- # false negative state_update
122
- self .false_negatives .assign_add (tf .cast (false_negative , self .dtype ))
123
- # true negative state_update
124
- self .true_negatives .assign_add (tf .cast (true_negative , self .dtype ))
92
+
93
+ self .conf_mtx .assign_add (new_conf_mtx )
125
94
126
95
def result (self ):
127
- # numerator
128
- numerator1 = self .true_positives * self .true_negatives
129
- numerator2 = self .false_positives * self .false_negatives
130
- numerator = numerator1 - numerator2
131
- # denominator
132
- denominator1 = self .true_positives + self .false_positives
133
- denominator2 = self .true_positives + self .false_negatives
134
- denominator3 = self .true_negatives + self .false_positives
135
- denominator4 = self .true_negatives + self .false_negatives
136
- denominator = tf .math .sqrt (
137
- denominator1 * denominator2 * denominator3 * denominator4
138
- )
139
- mcc = tf .math .divide_no_nan (numerator , denominator )
96
+
97
+ true_sum = tf .reduce_sum (self .conf_mtx , axis = 1 )
98
+ pred_sum = tf .reduce_sum (self .conf_mtx , axis = 0 )
99
+ num_correct = tf .linalg .trace (self .conf_mtx )
100
+ num_samples = tf .reduce_sum (pred_sum )
101
+
102
+ # covariance true-pred
103
+ cov_ytyp = num_correct * num_samples - tf .tensordot (true_sum , pred_sum , axes = 1 )
104
+ # covariance pred-pred
105
+ cov_ypyp = num_samples ** 2 - tf .tensordot (pred_sum , pred_sum , axes = 1 )
106
+ # covariance true-true
107
+ cov_ytyt = num_samples ** 2 - tf .tensordot (true_sum , true_sum , axes = 1 )
108
+
109
+ mcc = cov_ytyp / tf .math .sqrt (cov_ytyt * cov_ypyp )
110
+
111
+ if tf .math .is_nan (mcc ):
112
+ mcc = tf .constant (0 , dtype = self .dtype )
113
+
140
114
return mcc
141
115
142
116
def get_config (self ):
@@ -150,5 +124,9 @@ def get_config(self):
150
124
151
125
def reset_states (self ):
152
126
"""Resets all of the metric state variables."""
153
- reset_value = np .zeros (self .num_classes , dtype = self .dtype )
154
- K .batch_set_value ([(v , reset_value ) for v in self .variables ])
127
+
128
+ for v in self .variables :
129
+ K .set_value (
130
+ v ,
131
+ np .zeros ((self .num_classes , self .num_classes ), v .dtype .as_numpy_dtype ),
132
+ )
0 commit comments