
Commit 15d6a1b

Test
1 parent a5f1b90 commit 15d6a1b

4 files changed: 22 additions & 39 deletions


ignite/metrics/mean_average_precision.py

Lines changed: 7 additions & 15 deletions
@@ -2,7 +2,6 @@
 from typing import Callable, cast, List, Optional, Sequence, Tuple, Union
 
 import torch
-from packaging.version import Version
 from typing_extensions import Literal
 
 import ignite.distributed as idist
@@ -12,9 +11,6 @@
 from ignite.utils import to_onehot
 
 
-_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")
-
-
 class _BaseAveragePrecision:
     def __init__(
         self,
@@ -101,12 +97,9 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
         if self.rec_thresholds is not None:
             rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1))
             rec_thresh_indices = torch.searchsorted(recall, rec_thresholds)
-            rec_mask = rec_thresh_indices != recall.size(-1)
-            precision = torch.where(
-                rec_mask,
-                precision.take_along_dim(torch.where(rec_mask, rec_thresh_indices, 0), dim=-1),
-                0.0,
-            )
+            precision = precision.take_along_dim(
+                rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
+            ).where(rec_thresh_indices != recall.size(-1), 0)
             recall = rec_thresholds
         recall_differential = recall.diff(
             dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=recall.dtype)
@@ -342,10 +335,9 @@ def _compute_recall_and_precision(
         Returns:
             `(recall, precision)`
         """
-        kwargs = {} if _torch_version_lt_113 else {"stable": True}
-        indices = torch.argsort(y_pred, descending=True, **kwargs)
+        indices = torch.argsort(y_pred, stable=True, descending=True)
         tp_summation = y_true[indices].cumsum(dim=0)
-        if tp_summation.device.type != "mps":
+        if tp_summation.device != torch.device("mps"):
             tp_summation = tp_summation.double()
 
         # Adopted from Scikit-learn's implementation
@@ -362,7 +354,7 @@ def _compute_recall_and_precision(
         recall = tp_summation / y_true_positive_count
 
         predicted_positive = tp_summation + fp_summation
-        precision = tp_summation / torch.where(predicted_positive == 0, 1.0, predicted_positive)
+        precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
         return recall, precision
 
     def compute(self) -> Union[torch.Tensor, float]:
@@ -379,7 +371,7 @@ def compute(self) -> Union[torch.Tensor, float]:
             torch.long if self._type == "multiclass" else torch.uint8,
             self._device,
         )
-        fp_precision = torch.double if self._device.type != "mps" else torch.float32
+        fp_precision = torch.double if self._device != torch.device("mps") else torch.float32
        y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), fp_precision, self._device)
 
         if self._type == "multiclass":
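
For reference, a minimal standalone sketch (illustrative values, not part of this commit) of the take_along_dim/where pattern introduced above: out-of-range indices returned by torch.searchsorted are clamped to a placeholder index of 0 for the gather and then zeroed out afterwards.

import torch

recall = torch.tensor([[0.2, 0.5, 0.8]])
precision = torch.tensor([[1.0, 0.8, 0.6]])
rec_thresholds = torch.tensor([[0.0, 0.4, 0.9, 1.0]])

rec_thresh_indices = torch.searchsorted(recall, rec_thresholds)  # tensor([[0, 1, 3, 3]])
# Indices equal to recall.size(-1) point past the end of the curve: gather with a
# placeholder index of 0, then zero those positions afterwards.
interpolated = precision.take_along_dim(
    rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
).where(rec_thresh_indices != recall.size(-1), 0)
print(interpolated)  # tensor([[1.0000, 0.8000, 0.0000, 0.0000]])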

ignite/metrics/vision/object_detection_average_precision_recall.py

Lines changed: 6 additions & 15 deletions
@@ -1,7 +1,6 @@
 from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
-from packaging.version import Version
 from typing_extensions import Literal
 
 from ignite.metrics import MetricGroup
@@ -10,9 +9,6 @@
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
 
 
-_torch_version_lt_113 = Version(torch.__version__) < Version("1.13.0")
-
-
 def coco_tensor_list_to_dict_list(
     output: Tuple[
         Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -217,8 +213,7 @@ def _compute_recall_and_precision(
         Returns:
            `(recall, precision)`
        """
-        kwargs = {} if _torch_version_lt_113 else {"stable": True}
-        indices = torch.argsort(scores, descending=True, **kwargs)
+        indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
         tp = TP[..., indices]
         tp_summation = tp.cumsum(dim=-1)
         if tp_summation.device.type != "mps":
@@ -231,7 +226,7 @@ def _compute_recall_and_precision(
 
         recall = tp_summation / y_true_count
         predicted_positive = tp_summation + fp_summation
-        precision = tp_summation / torch.where(predicted_positive == 0, 1.0, predicted_positive)
+        precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
 
         return recall, precision
 
@@ -263,12 +258,9 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
             if recall.size(-1) != 0
             else torch.LongTensor([], device=self._device)
         )
-        recall_mask = rec_thresh_indices != recall.size(-1)
-        precision_integrand = torch.where(
-            recall_mask,
-            precision_integrand.take_along_dim(torch.where(recall_mask, rec_thresh_indices, 0), dim=-1),
-            0.0,
-        )
+        precision_integrand = precision_integrand.take_along_dim(
+            rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1
+        ).where(rec_thresh_indices != recall.size(-1), 0)
         return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds))
 
     @reinit__is_reduced
@@ -306,7 +298,6 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
                                     This key is optional.
         ========= ================= =================================================
         """
-        kwargs = {} if _torch_version_lt_113 else {"stable": True}
         self._check_matching_input(output)
         for pred, target in zip(*output):
             labels = target["labels"]
@@ -321,7 +312,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
 
             # Matching logic of object detection mAP, according to COCO reference implementation.
             if len(pred["labels"]):
-                best_detections_index = torch.argsort(pred["scores"], descending=True, **kwargs)
+                best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)
                 max_best_detections_index = torch.cat(
                     [
                         best_detections_index[pred["labels"][best_detections_index] == c][
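
For reference, a minimal standalone sketch (not from this commit) of the stable argsort now passed directly, without the removed _torch_version_lt_113 guard: stable=True keeps tied scores in their original relative order, which makes the tie-breaking deterministic.

import torch

scores = torch.tensor([0.9, 0.5, 0.9, 0.5])
indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
print(indices)  # tensor([0, 2, 1, 3]) -- tied scores keep their original relative order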

tests/ignite/metrics/test_mean_average_precision.py

Lines changed: 6 additions & 6 deletions
@@ -45,30 +45,30 @@ def test__prepare_output():
     metric = MeanAveragePrecision()
 
     metric._type = "binary"
-    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2))))
+    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool()))
     assert scores.shape == y.shape == (1, 120)
 
     metric._type = "multiclass"
     scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 4, (5, 3, 2))))
     assert scores.shape == (4, 30) and y.shape == (30,)
 
     metric._type = "multilabel"
-    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2))))
+    scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool()))
     assert scores.shape == y.shape == (4, 30)
 
 
 def test_update():
     metric = MeanAveragePrecision()
     assert len(metric._y_pred) == len(metric._y_true) == 0
-    metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4))))
+    metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4)).bool()))
     assert len(metric._y_pred) == len(metric._y_true) == 1
 
 
 def test__compute_recall_and_precision():
     m = MeanAveragePrecision()
 
     scores = torch.rand((50,))
-    y_true = torch.randint(0, 2, (50,))
+    y_true = torch.randint(0, 2, (50,)).bool()
     precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy())
     P = y_true.sum(dim=-1)
     ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P)
@@ -77,7 +77,7 @@ def test__compute_recall_and_precision():
 
     # When there's no actual positive. Numpy expectedly raises warning.
     scores = torch.rand((50,))
-    y_true = torch.zeros((50,))
+    y_true = torch.zeros((50,)).bool()
     precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy())
     P = torch.tensor(0)
     ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P)
@@ -147,7 +147,7 @@ def test_compute_nonbinary_data(class_mean):
 
     # Multilabel
     m = MeanAveragePrecision(is_multilabel=True, class_mean=class_mean)
-    y_true = torch.randint(0, 2, (130, 5, 2, 2))
+    y_true = torch.randint(0, 2, (130, 5, 2, 2)).bool()
     m.update((scores[:50], y_true[:50]))
     m.update((scores[50:], y_true[50:]))
     ignite_map = m.compute().numpy()
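
For reference, a minimal standalone sketch (illustrative shapes, not from the test suite) of what the added .bool() cast changes about the synthetic targets used above:

import torch

y_int = torch.randint(0, 2, (5,))           # dtype torch.int64, values in {0, 1}
y_bool = torch.randint(0, 2, (5,)).bool()   # dtype torch.bool, values in {False, True}
print(y_int.dtype, y_bool.dtype)            # torch.int64 torch.bool
print(y_bool.sum(dim=-1))                   # counting positives still works on a bool tensor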

tests/ignite/metrics/vision/test_object_detection_map.py

Lines changed: 3 additions & 3 deletions
@@ -872,7 +872,7 @@ def test__compute_recall_and_precision(available_device):
 def test_compute(sample):
     device = idist.device()
 
-    if device.type == "mps":
+    if device == torch.device("mps"):
         pytest.skip("Due to MPS backend out of memory")
 
     # AP@0.5..0.95, AP@0.5, AP@0.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L
@@ -932,7 +932,7 @@ def test_integration(sample):
     bs = 3
 
     device = idist.device()
-    if device.type == "mps":
+    if device == torch.device("mps"):
         pytest.skip("Due to MPS backend out of memory")
 
     def update(engine, i):
@@ -1003,7 +1003,7 @@ def test_distrib_update_compute(distributed, sample):
 
     device = idist.device()
 
-    if device.type == "mps":
+    if device == torch.device("mps"):
         pytest.skip("Due to MPS backend out of memory")
 
     metric_device = "cpu" if device.type == "xla" else device
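
For reference, a minimal standalone sketch (using cpu as a stand-in for idist.device(), not from the test suite) of the two device checks that appear in this diff, the new full-device comparison and the older type-string check:

import torch

device = torch.device("cpu")              # stand-in for idist.device()
print(device == torch.device("mps"))      # False: comparison of the full device object
print(device.type == "mps")               # False: check against the device type string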
