Commit 8494029: FROC metric in ND (#6528)
Fixes #5172.

### Description

Implementation of the FROC metric in ND.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.

Signed-off-by: Tomasz Bartczak <[email protected]>
Parent: b60f69e

File tree: 3 files changed (+93 −19 lines)
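Before the per-file diffs, a minimal sketch (not part of the commit) of how the new ND entry point slots into the existing FROC pipeline; the detections and mask reuse the 3D values from `TEST_CASE_ND_2` in the tests below, and the `eval_thresholds` shown are the standard CAMELYON 16 ones:

```python
import numpy as np
import torch

from monai.metrics import compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score

# three candidate detections: a confidence and an (n_dim,) index into the mask each
probs = torch.tensor([1.0, 0.6, 0.8])
coords = torch.tensor([[0, 0, 3], [1, 2, 0], [0, 3, 1]])

# ground truth: connected components labelled 1..3 in a (2, 4, 4) volume
evaluation_mask = np.array(
    [
        [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],
        [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],
    ]
)

# split detections into false/true positives against the labelled regions
fp_probs, tp_probs, num_targets = compute_fp_tp_probs_nd(probs, coords, evaluation_mask)

# turn them into an FROC curve and average the sensitivity at fixed FP rates
fp_per_image, sensitivity = compute_froc_curve_data(fp_probs, tp_probs, num_targets, num_images=1)
score = compute_froc_score(fp_per_image, sensitivity, eval_thresholds=(0.25, 0.5, 1, 2, 4, 8))
```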

monai/metrics/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore
-from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
+from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score
 from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
 from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
 from .loss_metric import LossMetric
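In user code, the effect of this one-line change is simply that the ND variant becomes importable from the package top level alongside the 2D original, e.g.:

```python
from monai.metrics import compute_fp_tp_probs, compute_fp_tp_probs_nd
```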

monai/metrics/froc.py
Lines changed: 55 additions & 17 deletions

@@ -19,13 +19,11 @@
 from monai.config import NdarrayOrTensor


-def compute_fp_tp_probs(
+def compute_fp_tp_probs_nd(
     probs: NdarrayOrTensor,
-    y_coord: NdarrayOrTensor,
-    x_coord: NdarrayOrTensor,
+    coords: NdarrayOrTensor,
     evaluation_mask: NdarrayOrTensor,
     labels_to_exclude: list | None = None,
-    resolution_level: int = 0,
 ) -> tuple[NdarrayOrTensor, NdarrayOrTensor, int]:
     """
     This function is modified from the official evaluation code of
@@ -36,29 +34,28 @@ def compute_fp_tp_probs(
     Args:
         probs: an array with shape (n,) that represents the probabilities of the detections.
             Where, n is the number of predicted detections.
-        y_coord: an array with shape (n,) that represents the Y-coordinates of the detections.
-        x_coord: an array with shape (n,) that represents the X-coordinates of the detections.
+        coords: an array with shape (n, n_dim) that represents the coordinates of the detections.
+            The dimensions must be in the same order as in `evaluation_mask`.
         evaluation_mask: the ground truth mask for evaluation.
         labels_to_exclude: labels in this list will not be counted for metric calculation.
-        resolution_level: the level at which the evaluation mask is made.

     Returns:
         fp_probs: an array that contains the probabilities of the false positive detections.
         tp_probs: an array that contains the probabilities of the True positive detections.
         num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.

     """
-    if not (probs.shape == y_coord.shape == x_coord.shape):
+    if not (len(probs) == len(coords)):
+        raise ValueError(f"the length of probs {probs.shape}, should be the same as of coords {coords.shape}.")
+    if not (len(coords.shape) > 1 and coords.shape[1] == len(evaluation_mask.shape)):
         raise ValueError(
-            f"the shapes between probs {probs.shape}, y_coord {y_coord.shape} and x_coord {x_coord.shape} should be the same."
+            f"coords {coords.shape} need to represent the same number of dimensions as mask {evaluation_mask.shape}."
         )

     if isinstance(probs, torch.Tensor):
         probs = probs.detach().cpu().numpy()
-    if isinstance(y_coord, torch.Tensor):
-        y_coord = y_coord.detach().cpu().numpy()
-    if isinstance(x_coord, torch.Tensor):
-        x_coord = x_coord.detach().cpu().numpy()
+    if isinstance(coords, torch.Tensor):
+        coords = coords.detach().cpu().numpy()
     if isinstance(evaluation_mask, torch.Tensor):
         evaluation_mask = evaluation_mask.detach().cpu().numpy()

@@ -68,10 +65,7 @@ def compute_fp_tp_probs(
     max_label = np.max(evaluation_mask)
     tp_probs = np.zeros((max_label,), dtype=np.float32)

-    y_coord = (y_coord / pow(2, resolution_level)).astype(int)
-    x_coord = (x_coord / pow(2, resolution_level)).astype(int)
-
-    hittedlabel = evaluation_mask[y_coord, x_coord]
+    hittedlabel = evaluation_mask[tuple(coords.T)]
     fp_probs = probs[np.where(hittedlabel == 0)]
     for i in range(1, max_label + 1):
         if i not in labels_to_exclude and i in hittedlabel:
@@ -81,6 +75,50 @@ def compute_fp_tp_probs(
     return fp_probs, tp_probs, cast(int, num_targets)


+def compute_fp_tp_probs(
+    probs: NdarrayOrTensor,
+    y_coord: NdarrayOrTensor,
+    x_coord: NdarrayOrTensor,
+    evaluation_mask: NdarrayOrTensor,
+    labels_to_exclude: list | None = None,
+    resolution_level: int = 0,
+) -> tuple[NdarrayOrTensor, NdarrayOrTensor, int]:
+    """
+    This function is modified from the official evaluation code of
+    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish
+    true positive and false positive predictions. A true positive prediction is defined when
+    the detection point is within the annotated ground truth region.
+
+    Args:
+        probs: an array with shape (n,) that represents the probabilities of the detections.
+            Where, n is the number of predicted detections.
+        y_coord: an array with shape (n,) that represents the Y-coordinates of the detections.
+        x_coord: an array with shape (n,) that represents the X-coordinates of the detections.
+        evaluation_mask: the ground truth mask for evaluation.
+        labels_to_exclude: labels in this list will not be counted for metric calculation.
+        resolution_level: the level at which the evaluation mask is made.
+
+    Returns:
+        fp_probs: an array that contains the probabilities of the false positive detections.
+        tp_probs: an array that contains the probabilities of the True positive detections.
+        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.
+
+    """
+    if isinstance(y_coord, torch.Tensor):
+        y_coord = y_coord.detach().cpu().numpy()
+    if isinstance(x_coord, torch.Tensor):
+        x_coord = x_coord.detach().cpu().numpy()
+
+    y_coord = (y_coord / pow(2, resolution_level)).astype(int)
+    x_coord = (x_coord / pow(2, resolution_level)).astype(int)
+
+    stacked = np.stack([y_coord, x_coord], axis=1)
+
+    return compute_fp_tp_probs_nd(
+        probs=probs, coords=stacked, evaluation_mask=evaluation_mask, labels_to_exclude=labels_to_exclude
+    )
+
+
 def compute_froc_curve_data(
     fp_probs: np.ndarray | torch.Tensor, tp_probs: np.ndarray | torch.Tensor, num_targets: int, num_images: int
 ) -> tuple[np.ndarray, np.ndarray]:
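The heart of the generalization is the lookup `evaluation_mask[tuple(coords.T)]`: transposing the `(n, n_dim)` coordinate array and wrapping it in a tuple gives NumPy one integer index array per axis, which reduces to the old `evaluation_mask[y_coord, x_coord]` in 2D and works unchanged for any rank. A small sketch, with values borrowed from `TEST_CASE_ND_1` below:

```python
import numpy as np

mask = np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]])
coords = np.array([[0, 3], [2, 0], [3, 1]])  # (n, n_dim) rows of (y, x)

# tuple(coords.T) == (array([0, 2, 3]), array([3, 0, 1])): one index array per axis
hit_nd = mask[tuple(coords.T)]
hit_2d = mask[coords[:, 0], coords[:, 1]]  # the 2D-only form this replaces
assert (hit_nd == hit_2d).all()
print(hit_nd)  # [1 0 3]: the label each detection lands on; 0 is background, i.e. a false positive
```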

tests/test_compute_froc.py
Lines changed: 37 additions & 1 deletion

@@ -17,7 +17,7 @@
 import torch
 from parameterized import parameterized

-from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
+from monai.metrics import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score

 _device = "cuda:0" if torch.cuda.is_available() else "cpu"
 TEST_CASE_1 = [
@@ -82,6 +82,33 @@
     0.75,
 ]

+TEST_CASE_ND_1 = [
+    {
+        "probs": torch.tensor([1, 0.6, 0.8]),
+        "coords": torch.tensor([[0, 3], [2, 0], [3, 1]]),
+        "evaluation_mask": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]),
+    },
+    np.array([0.6]),
+    np.array([1, 0, 0.8]),
+    3,
+]
+
+TEST_CASE_ND_2 = [
+    {
+        "probs": torch.tensor([1, 0.6, 0.8]),
+        "coords": torch.tensor([[0, 0, 3], [1, 2, 0], [0, 3, 1]]),
+        "evaluation_mask": np.array(
+            [
+                [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],
+                [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],
+            ]
+        ),
+    },
+    np.array([0.6]),
+    np.array([1, 0, 0.8]),
+    3,
+]
+

 class TestComputeFpTp(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@@ -92,6 +119,15 @@ def test_value(self, input_data, expected_fp, expected_tp, expected_num):
         np.testing.assert_equal(num_tumors, expected_num)


+class TestComputeFpTpNd(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_ND_1, TEST_CASE_ND_2])
+    def test_value(self, input_data, expected_fp, expected_tp, expected_num):
+        fp_probs, tp_probs, num_tumors = compute_fp_tp_probs_nd(**input_data)
+        np.testing.assert_allclose(fp_probs, expected_fp, rtol=1e-5)
+        np.testing.assert_allclose(tp_probs, expected_tp, rtol=1e-5)
+        np.testing.assert_equal(num_tumors, expected_num)
+
+
 class TestComputeFrocScore(unittest.TestCase):
     @parameterized.expand([TEST_CASE_4, TEST_CASE_5])
     def test_value(self, input_data, thresholds, expected_score):
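For a sense of where the expected values in `TEST_CASE_ND_1` come from: the detection at (0, 3) with probability 1 lands in region 1, the one at (3, 1) with probability 0.8 in region 3, and the one at (2, 0) with probability 0.6 on background. So `fp_probs` is `[0.6]`, `tp_probs` is `[1, 0, 0.8]` (the 0 records that region 2 is never hit), and `num_targets` is 3. A quick hand check, for illustration only:

```python
import numpy as np

mask = np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]])
for (y, x), p in zip([(0, 3), (2, 0), (3, 1)], [1.0, 0.6, 0.8]):
    # a hit on label 0 (background) counts as a false positive
    print(f"p={p} hits label {mask[y, x]}")  # labels 1, 0, 3
```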
