Add CalibrationErrorMetric and CalibrationError handler #8707
Open
theo-barfoot wants to merge 2 commits into Project-MONAI:dev from theo-barfoot:feature/calibration-metrics
+874
−0
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
|
|
||
| from monai.handlers.ignite_metric import IgniteMetricHandler | ||
| from monai.metrics import CalibrationErrorMetric, CalibrationReduction | ||
| from monai.utils import MetricReduction | ||
|
|
||
| __all__ = ["CalibrationError"] | ||
|
|
||
|
|
||
| class CalibrationError(IgniteMetricHandler): | ||
| """ | ||
| Computes Calibration Error and reports the aggregated value according to `metric_reduction` | ||
| over all accumulated iterations. Can return the expected, average, or maximum calibration error. | ||
|
|
||
| Args: | ||
| num_bins: number of bins to calculate calibration. Defaults to 20. | ||
| include_background: whether to include calibration error computation on the first channel of | ||
| the predicted output. Defaults to True. | ||
| calibration_reduction: Method for calculating calibration error values from binned data. | ||
| Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. | ||
| metric_reduction: Mode of reduction to apply to the metrics. | ||
| Reduction is only applied to non-NaN values. | ||
| Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, | ||
| `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. | ||
| Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. | ||
| output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then | ||
| construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or | ||
| lists of `channel-first` Tensors. The form `(y_pred, y)` is required by `update()`. | ||
| `engine.state` and `output_transform` inherit from the ignite concept: | ||
| https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: | ||
| https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. | ||
| save_details: whether to save per-image metric computation details, for example the calibration error | ||
| of every image. Defaults to True, which saves them to the `engine.state.metric_details` dict with the | ||
| metric name as key. | ||
|
|
||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_bins: int = 20, | ||
| include_background: bool = True, | ||
| calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, | ||
| metric_reduction: MetricReduction | str = MetricReduction.MEAN, | ||
| output_transform: Callable = lambda x: x, | ||
| save_details: bool = True, | ||
| ) -> None: | ||
| metric_fn = CalibrationErrorMetric( | ||
| num_bins=num_bins, | ||
| include_background=include_background, | ||
| calibration_reduction=calibration_reduction, | ||
| metric_reduction=metric_reduction, | ||
| ) | ||
|
|
||
| super().__init__( | ||
| metric_fn=metric_fn, | ||
| output_transform=output_transform, | ||
| save_details=save_details, | ||
| ) |
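
For reference, the three `calibration_reduction` modes that this handler forwards to `CalibrationErrorMetric` can be written out as follows. This is a reading of the reduction code in this PR rather than text from it; the symbols are introduced here for illustration. For one batch item and one channel, let bin $m$ of $M$ hold $n_m$ elements with mean predicted probability $\bar{p}_m$ and mean ground truth $\bar{y}_m$, let $N = \sum_m n_m$, and let $\mathcal{B} = \{m : n_m > 0\}$ be the set of non-empty bins.

```latex
\begin{align}
\mathrm{ECE} &= \sum_{m \in \mathcal{B}} \frac{n_m}{N}\,\bigl|\bar{p}_m - \bar{y}_m\bigr|
  && \text{(\texttt{"expected"}: weighted by bin count)} \\
\mathrm{ACE} &= \frac{1}{|\mathcal{B}|} \sum_{m \in \mathcal{B}} \bigl|\bar{p}_m - \bar{y}_m\bigr|
  && \text{(\texttt{"average"}: unweighted mean over non-empty bins)} \\
\mathrm{MCE} &= \max_{m \in \mathcal{B}} \bigl|\bar{p}_m - \bar{y}_m\bigr|
  && \text{(\texttt{"maximum"}: worst non-empty bin)}
\end{align}
```

Each of these yields one value per (batch, channel) pair; `metric_reduction` then aggregates those values across batch and channel.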
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,260 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any | ||
|
|
||
| import torch | ||
|
|
||
| from monai.metrics.metric import CumulativeIterationMetric | ||
| from monai.metrics.utils import do_metric_reduction, ignore_background | ||
| from monai.utils import MetricReduction | ||
| from monai.utils.enums import StrEnum | ||
|
|
||
| __all__ = [ | ||
| "calibration_binning", | ||
| "CalibrationErrorMetric", | ||
| "CalibrationReduction", | ||
| ] | ||
|
|
||
|
|
||
| def calibration_binning( | ||
| y_pred: torch.Tensor, y: torch.Tensor, num_bins: int = 20, right: bool = False | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Compute calibration bins for predicted probabilities and ground truth labels. | ||
| This function calculates the mean predicted probabilities, mean ground truths, | ||
| and bin counts for each bin using a hard binning calibration approach. | ||
|
|
||
| The function operates on input and target tensors with batch and channel dimensions, | ||
| handling each batch and channel separately. For bins that do not contain any elements, | ||
| the mean predicted values and mean ground truth values are set to NaN. | ||
|
|
||
| Args: | ||
| y_pred: predicted tensor with shape [batch, channel, spatial], where spatial | ||
| can be any number of dimensions. Values are probabilities and should lie in | ||
| the range [0, 1]. | ||
| y: Target tensor with the same shape as y_pred. It represents ground truth values. | ||
| num_bins: The number of bins to use for calibration. Defaults to 20. Must be >= 1. | ||
| right: Forwarded to `torch.bucketize`. If False (default), a value equal to a bin edge is assigned to | ||
| the bin ending at that edge, so bins are (left, right]. If True, it is assigned to the bin starting | ||
| at that edge, so bins are [left, right). | ||
|
|
||
| Returns: | ||
| A tuple of three tensors: | ||
| - mean_p_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing | ||
| the mean predicted values in each bin. | ||
| - mean_gt_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing | ||
| the mean ground truth values in each bin. | ||
| - bin_counts: Tensor of shape [batch_size, num_channels, num_bins] containing | ||
| the count of elements in each bin. | ||
|
|
||
| Raises: | ||
| ValueError: If the input and target shapes do not match, if the input has fewer than 3 dimensions, | ||
| or if num_bins < 1. | ||
|
|
||
| Note: | ||
| This function currently uses nested for loops over batch and channel dimensions | ||
| for binning operations. Future improvements may include vectorizing these operations | ||
| for enhanced performance. | ||
| """ | ||
| # Input validation | ||
| if y_pred.shape != y.shape: | ||
| raise ValueError(f"y_pred and y must have the same shape, got {y_pred.shape} and {y.shape}.") | ||
| if y_pred.ndim < 3: | ||
| raise ValueError(f"y_pred must have shape (B, C, spatial...), got ndim={y_pred.ndim}.") | ||
| if num_bins < 1: | ||
| raise ValueError(f"num_bins must be >= 1, got {num_bins}.") | ||
|
|
||
| batch_size, num_channels = y_pred.shape[:2] | ||
| boundaries = torch.linspace( | ||
| start=0.0, | ||
| end=1.0 + torch.finfo(torch.float32).eps, | ||
| steps=num_bins + 1, | ||
| device=y_pred.device, | ||
| ) | ||
|
|
||
| mean_p_per_bin = torch.zeros(batch_size, num_channels, num_bins, device=y_pred.device) | ||
| mean_gt_per_bin = torch.zeros_like(mean_p_per_bin) | ||
| bin_counts = torch.zeros_like(mean_p_per_bin) | ||
|
|
||
| y_pred_flat = y_pred.flatten(start_dim=2).float() | ||
| y_flat = y.flatten(start_dim=2).float() | ||
|
|
||
| for b in range(batch_size): | ||
| for c in range(num_channels): | ||
| values_p = y_pred_flat[b, c, :] | ||
| values_gt = y_flat[b, c, :] | ||
|
|
||
| # Compute bin indices and clamp to valid range to handle out-of-range values | ||
| bin_idx = torch.bucketize(values_p, boundaries[1:], right=right) | ||
| bin_idx = bin_idx.clamp(max=num_bins - 1) | ||
|
|
||
| # Compute bin counts using scatter_add | ||
| counts = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) | ||
| counts.scatter_add_(0, bin_idx, torch.ones_like(values_p)) | ||
| bin_counts[b, c, :] = counts | ||
|
|
||
| # Compute sums for mean calculation using scatter_add (more compatible than scatter_reduce) | ||
| sum_p = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) | ||
| sum_p.scatter_add_(0, bin_idx, values_p) | ||
|
|
||
| sum_gt = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) | ||
| sum_gt.scatter_add_(0, bin_idx, values_gt) | ||
|
|
||
| # Compute means, avoiding division by zero | ||
| safe_counts = counts.clamp(min=1) | ||
| mean_p_per_bin[b, c, :] = sum_p / safe_counts | ||
| mean_gt_per_bin[b, c, :] = sum_gt / safe_counts | ||
|
|
||
| # Set empty bins to NaN | ||
| mean_p_per_bin[bin_counts == 0] = torch.nan | ||
| mean_gt_per_bin[bin_counts == 0] = torch.nan | ||
|
|
||
| return mean_p_per_bin, mean_gt_per_bin, bin_counts | ||
|
|
||
|
|
||
| class CalibrationReduction(StrEnum): | ||
| """ | ||
| Enumeration of calibration error reduction methods. | ||
|
|
||
| - EXPECTED: Expected Calibration Error (ECE) - weighted average by bin count | ||
| - AVERAGE: Average Calibration Error (ACE) - simple average across bins | ||
| - MAXIMUM: Maximum Calibration Error (MCE) - maximum error across bins | ||
| """ | ||
|
|
||
| EXPECTED = "expected" | ||
| AVERAGE = "average" | ||
| MAXIMUM = "maximum" | ||
|
|
||
|
|
||
| class CalibrationErrorMetric(CumulativeIterationMetric): | ||
| """ | ||
| Compute the Calibration Error between predicted probabilities and ground truth labels. | ||
| This metric is suitable for multi-class tasks and supports batched inputs. | ||
|
|
||
| The input `y_pred` represents the model's predicted probabilities, and `y` represents the ground truth labels. | ||
| `y_pred` is expected to have probabilities, and `y` should be in one-hot format. You can use suitable transforms | ||
| in `monai.transforms.post` to achieve the desired format. | ||
|
|
||
| The `include_background` parameter can be set to `False` to exclude the first category (channel index 0), | ||
| which is conventionally assumed to be the background. This is particularly useful in segmentation tasks where | ||
| the background class might skew the calibration results. | ||
|
|
||
| The metric supports both single-channel and multi-channel data. For multi-channel data, the input tensors | ||
| should be in the format of BCHW[D], where B is the batch size, C is the number of channels, and HW[D] | ||
| are the spatial dimensions. | ||
|
|
||
| Args: | ||
| num_bins: Number of bins to divide probabilities into for calibration calculation. Defaults to 20. | ||
| include_background: Whether to include computation on the first channel of the predicted output. | ||
| Defaults to `True`. | ||
| calibration_reduction: Method for calculating calibration error values from binned data. | ||
| Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. | ||
| metric_reduction: Mode of reduction to apply to the metrics. | ||
| Reduction is only applied to non-NaN values. | ||
| Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, | ||
| `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. | ||
| Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. | ||
| get_not_nans: Whether to return the count of non-NaN values. | ||
| If `True`, `aggregate()` returns a tuple (metric, not_nans). Defaults to `False`. | ||
| right: Forwarded to `calibration_binning`; controls which bin a value lying exactly on a bin edge | ||
| falls into. Defaults to `False`. | ||
|
|
||
| Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. | ||
|
|
||
| Example: | ||
| >>> from monai.transforms import Activations, AsDiscrete | ||
| >>> # Transforms to convert model outputs to probabilities and labels to one-hot | ||
| >>> softmax = Activations(softmax=True) # or sigmoid=True for binary/multi-label | ||
| >>> to_onehot = AsDiscrete(to_onehot=num_classes) | ||
| >>> metric = CalibrationErrorMetric(num_bins=15, include_background=False, calibration_reduction="expected") | ||
| >>> for batch_data in dataloader: | ||
| ...     logits, labels = model(batch_data) | ||
| ...     preds = softmax(logits) # convert logits to probabilities | ||
| ...     labels_onehot = to_onehot(labels) # convert labels to one-hot format | ||
| ...     metric(y_pred=preds, y=labels_onehot) | ||
| >>> ece = metric.aggregate() | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_bins: int = 20, | ||
| include_background: bool = True, | ||
| calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, | ||
| metric_reduction: MetricReduction | str = MetricReduction.MEAN, | ||
| get_not_nans: bool = False, | ||
| right: bool = False, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.num_bins = num_bins | ||
| self.include_background = include_background | ||
| self.calibration_reduction = CalibrationReduction(calibration_reduction) | ||
| self.metric_reduction = metric_reduction | ||
| self.get_not_nans = get_not_nans | ||
| self.right = right | ||
|
|
||
| def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] | ||
| """ | ||
| Compute calibration error for the given predictions and ground truth. | ||
|
|
||
| Args: | ||
| y_pred: input data to compute. It should be in the format of (batch, channel, spatial...). | ||
| It represents probability predictions of the model. | ||
| y: ground truth in one-hot format. It should be in the format of (batch, channel, spatial...). | ||
| The values should be binarized. | ||
| **kwargs: additional keyword arguments (unused, for API compatibility). | ||
|
|
||
| Returns: | ||
| Calibration error tensor with shape (batch, channel). | ||
| """ | ||
| if not self.include_background: | ||
| y_pred, y = ignore_background(y_pred=y_pred, y=y) | ||
|
|
||
| mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning( | ||
| y_pred=y_pred, y=y, num_bins=self.num_bins, right=self.right | ||
| ) | ||
|
|
||
| # Absolute per-bin gap; empty bins stay NaN and are handled by the reductions below | ||
| abs_diff = torch.abs(mean_p_per_bin - mean_gt_per_bin) | ||
|
|
||
| if self.calibration_reduction == CalibrationReduction.EXPECTED: | ||
| # Calculate the weighted sum of absolute differences | ||
| return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1) | ||
| elif self.calibration_reduction == CalibrationReduction.AVERAGE: | ||
| return torch.nanmean(abs_diff, dim=-1) # Unweighted mean over bins, ignoring NaN (empty bins) | ||
| elif self.calibration_reduction == CalibrationReduction.MAXIMUM: | ||
| abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0) | ||
| return torch.max(abs_diff_no_nan, dim=-1).values # Maximum over bins | ||
| else: | ||
| raise ValueError(f"Unsupported calibration reduction: {self.calibration_reduction}") | ||
|
|
||
| def aggregate( | ||
| self, reduction: MetricReduction | str | None = None | ||
| ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Execute reduction logic for the output of `_compute_tensor`. | ||
|
|
||
| Args: | ||
| reduction: mode of reduction to apply to the metrics; reduction is only applied to non-NaN values. | ||
| Available modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, | ||
| ``"mean_channel"``, ``"sum_channel"``}. Defaults to `self.metric_reduction`. If ``"none"``, no | ||
| reduction is performed. | ||
|
|
||
| Returns: | ||
| If `get_not_nans` is True, returns a tuple (metric, not_nans), otherwise returns only the metric. | ||
| """ | ||
| data = self.get_buffer() | ||
| if not isinstance(data, torch.Tensor): | ||
| raise ValueError("the data to aggregate must be a PyTorch Tensor.") | ||
|
|
||
| # do metric reduction | ||
| f, not_nans = do_metric_reduction(data, reduction or self.metric_reduction) | ||
| return (f, not_nans) if self.get_not_nans else f | ||
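
To make the binning and reduction steps above concrete, here is a minimal framework-free sketch for a single channel of a single batch item. It is an illustration, not the PR's implementation: `bin_stats` and `calibration_error` are hypothetical names, the standard-library `bisect` functions stand in for `torch.bucketize`, and the bin edges are computed as `(i + 1) / num_bins` rather than with `torch.linspace` and an epsilon. Returning NaN when there is no data at all is a choice made for this sketch.

```python
import math
from bisect import bisect_left, bisect_right


def bin_stats(probs, labels, num_bins=20, right=False):
    """Hard-bin probabilities; return per-bin (mean_p, mean_gt, counts) lists."""
    # Only the upper edge of each bin is needed, mirroring `boundaries[1:]`.
    upper_edges = [(i + 1) / num_bins for i in range(num_bins)]
    # bisect_left matches bucketize(right=False): edge values go to the bin
    # ending at that edge. bisect_right matches bucketize(right=True).
    locate = bisect_right if right else bisect_left
    sum_p = [0.0] * num_bins
    sum_gt = [0.0] * num_bins
    counts = [0] * num_bins
    for p, g in zip(probs, labels):
        idx = min(locate(upper_edges, p), num_bins - 1)  # clamp to last bin
        sum_p[idx] += p
        sum_gt[idx] += g
        counts[idx] += 1
    # Empty bins are reported as NaN, as in the PR.
    mean_p = [s / n if n else math.nan for s, n in zip(sum_p, counts)]
    mean_gt = [s / n if n else math.nan for s, n in zip(sum_gt, counts)]
    return mean_p, mean_gt, counts


def calibration_error(mean_p, mean_gt, counts, reduction="expected"):
    """Reduce per-bin gaps over non-empty bins to a single value."""
    pairs = [(abs(p - g), n) for p, g, n in zip(mean_p, mean_gt, counts) if n > 0]
    if not pairs:
        return math.nan  # assumption of this sketch: no data means no value
    if reduction == "expected":
        return sum(d * n for d, n in pairs) / sum(n for _, n in pairs)
    if reduction == "average":
        return sum(d for d, _ in pairs) / len(pairs)
    if reduction == "maximum":
        return max(d for d, _ in pairs)
    raise ValueError(f"Unsupported calibration reduction: {reduction}")


if __name__ == "__main__":
    stats = bin_stats([0.1, 0.2, 0.6, 0.9], [0, 0, 1, 1], num_bins=4)
    for mode in ("expected", "average", "maximum"):
        print(mode, calibration_error(*stats, reduction=mode))
```

With four bins, the inputs above fall into bins 0, 0, 2 and 3, leaving bin 1 empty. The per-bin gaps are 0.15, 0.4 and 0.1, so the expected error is 0.2, the average is about 0.217, and the maximum is 0.4.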
Use sentinel value to preserve NaN when all bins are empty in MAXIMUM reduction
When all bins are empty (all NaN), `nan_to_num(..., nan=0.0)` converts the NaN values to 0, causing MCE to return 0 instead of NaN. This misrepresents "no valid data" as "zero error". Replace the fill value with `-inf` as a sentinel and restore NaN for all-NaN cases.
Additionally, add a test case for the all-empty-bins edge case to prevent regression.
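
The reviewer's collapsed code suggestion is not reproduced on this page. As a rough, framework-free sketch of the sentinel idea only, the helper below (a hypothetical `max_calibration_gap`, operating on a plain list of per-bin gaps for one batch item and channel) fills NaN with `-inf`, takes the maximum, and maps a result of `-inf` back to NaN. A tensor version would presumably follow the same three steps along the bin dimension.

```python
import math


def max_calibration_gap(abs_diff):
    """Largest non-NaN gap, or NaN when every bin is empty (all NaN)."""
    # Fill empty bins with -inf so they can never win the maximum.
    filled = [-math.inf if math.isnan(d) else d for d in abs_diff]
    result = max(filled, default=-math.inf)
    # Real gaps are absolute values (>= 0), so -inf can only be the sentinel.
    return math.nan if result == -math.inf else result


if __name__ == "__main__":
    print(max_calibration_gap([math.nan, 0.4, 0.1]))  # a valid bin exists
    print(max_calibration_gap([math.nan, math.nan]))  # no valid bin
```

The key property is that a genuine zero gap and "no data" stay distinguishable: a list such as `[0.0, nan]` still yields `0.0`, while an all-NaN list yields NaN.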