@@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
61
61
num_classes = 2 if self ._type == "binary" else y_pred .size (1 )
62
62
if self ._type == "multiclass" and y .max () + 1 > num_classes :
63
63
raise ValueError (
64
- f"y_pred contains less classes than y. Number of predicted classes is { num_classes } "
65
- f" and element in y has invalid class = { y .max ().item () + 1 } ."
64
+ f"y_pred contains fewer classes than y. Number of classes in the prediction is { num_classes } "
65
+ f" and an element in y has invalid class = { y .max ().item () + 1 } ."
66
66
)
67
67
y = y .view (- 1 )
68
68
if self ._type == "binary" and self ._average is False :
@@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
86
86
87
87
@reinit__is_reduced
88
88
def reset (self ) -> None :
89
- # `numerator`, `denominator` and `weight` are three variables chosen to be abstract
90
- # representatives of the ones that are measured for cases with different `average` parameters.
91
- # `weight` is only used when `average='weighted'`. Actual value of these three variables is
92
- # as follows.
93
- #
94
- # average='samples':
95
- # numerator (torch.Tensor): sum of metric value for samples
96
- # denominator (int): number of samples
97
- #
98
- # average='weighted':
99
- # numerator (torch.Tensor): number of true positives per class/label
100
- # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
101
- # positives per class/label
102
- # weight (torch.Tensor): number of actual positives per class
103
- #
104
- # average='micro':
105
- # numerator (torch.Tensor): sum of number of true positives for classes/labels
106
- # denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives
107
- # for classes/labels
108
- #
109
- # average='macro' or boolean or None:
110
- # numerator (torch.Tensor): number of true positives per class/label
111
- # denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
112
- # positives per class/label
89
+ """
90
+ `numerator`, `denominator` and `weight` are three variables chosen to be abstract
91
+ representatives of the ones that are measured for cases with different `average` parameters.
92
+ `weight` is only used when `average='weighted'`. Actual value of these three variables is
93
+ as follows.
94
+
95
+ average='samples':
96
+ numerator (torch.Tensor): sum of metric value for samples
97
+ denominator (int): number of samples
98
+
99
+ average='weighted':
100
+ numerator (torch.Tensor): number of true positives per class/label
101
+ denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
102
+ class/label.
103
+ weight (torch.Tensor): number of actual positives per class
104
+
105
+ average='micro':
106
+ numerator (torch.Tensor): sum of number of true positives for classes/labels
107
+ denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives for
108
+ classes/labels.
109
+
110
+ average='macro' or boolean or None:
111
+ numerator (torch.Tensor): number of true positives per class/label
112
+ denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
113
+ class/label.
114
+ """
113
115
114
116
self ._numerator : Union [int , torch .Tensor ] = 0
115
117
self ._denominator : Union [int , torch .Tensor ] = 0
@@ -120,16 +122,20 @@ def reset(self) -> None:
120
122
121
123
@sync_all_reduce ("_numerator" , "_denominator" )
122
124
def compute (self ) -> Union [torch .Tensor , float ]:
123
- # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
124
- #
125
- # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
126
- #
127
- # wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C`
128
- # for the `macro` one. :math:`C` is the number of classes/labels.
129
- #
130
- # Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
131
- #
132
- # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator }
125
+ r"""
126
+ Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
127
+
128
+ .. math::
129
+ \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
130
+
131
+ wherein `weight` is the internal variable `_weight` for `'weighted'` option and :math:`1/C`
132
+ for the `macro` one. :math:`C` is the number of classes/labels.
133
+
134
+ Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
135
+
136
+ .. math::
137
+ \text{Precision/Recall} = \frac{ numerator }{ denominator }
138
+ """
133
139
134
140
if not self ._updated :
135
141
raise NotComputableError (
@@ -367,6 +373,33 @@ def thresholded_output_transform(output):
367
373
368
374
@reinit__is_reduced
369
375
def update (self , output : Sequence [torch .Tensor ]) -> None :
376
+ r"""
377
+ Update the metric state using prediction and target.
378
+
379
+ Args:
380
+ output: a binary tuple of tensors (y_pred, y) whose shapes follow the table below. N stands for the batch
381
+ dimension, `...` for possible additional dimensions and C for class dimension.
382
+
383
+ .. list-table::
384
+ :widths: 20 10 10 10
385
+ :header-rows: 1
386
+
387
+ * - Output member\\Data type
388
+ - Binary
389
+ - Multiclass
390
+ - Multilabel
391
+ * - y_pred
392
+ - (N, ...)
393
+ - (N, C, ...)
394
+ - (N, C, ...)
395
+ * - y
396
+ - (N, ...)
397
+ - (N, ...)
398
+ - (N, C, ...)
399
+
400
+ For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass
401
+ data, y_pred and y should consist of probabilities and integers respectively.
402
+ """
370
403
self ._check_shape (output )
371
404
self ._check_type (output )
372
405
y_pred , y , correct = self ._prepare_output (output )
0 commit comments