Add CalibrationErrorMetric and CalibrationError handler #8707

base: dev
Conversation
- Add calibration_binning() function for hard binning calibration
- Add CalibrationErrorMetric with ECE/ACE/MCE reduction modes
- Add CalibrationError Ignite handler
- Add comprehensive tests for metrics and handler

Addresses Project-MONAI#8505

Signed-off-by: Theo Barfoot <[email protected]>
📝 Walkthrough

This pull request introduces a new calibration error metric system for MONAI. A new metrics module (monai/metrics/calibration.py) implements the core functionality.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/metrics/calibration.py`:
- Around line 228-235: In the CalibrationReduction.MAXIMUM branch, don’t convert
NaN to 0 (which hides “no data”); instead use a -inf sentinel when calling
torch.nan_to_num on abs_diff (e.g. nan=-torch.inf), take the max along dim=-1,
then detect buckets that were all-NaN (e.g. all_nan_mask =
torch.isnan(abs_diff).all(dim=-1)) and restore those positions in the result to
NaN; update the method where self.calibration_reduction is checked (the MAXIMUM
branch that uses abs_diff_no_nan) accordingly and add a unit test covering the
“all bins empty” case to prevent regressions.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (6)
- monai/handlers/__init__.py
- monai/handlers/calibration.py
- monai/metrics/__init__.py
- monai/metrics/calibration.py
- tests/handlers/test_handler_calibration_error.py
- tests/metrics/test_calibration_metric.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides and are sensible and informative with regard to their function, though simple names are permitted for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and should describe each variable, return value, and raised exception in the appropriate section of the Google style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
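As a minimal illustration of the Google-style docstring convention this configuration asks for (the function below is hypothetical, not part of the PR):

```python
def scale_probabilities(probs: list[float], factor: float) -> list[float]:
    """Scale a list of probabilities by a constant factor.

    Args:
        probs: Probabilities to scale, each expected in [0, 1].
        factor: Multiplicative scaling factor.

    Returns:
        A new list with each probability multiplied by ``factor``.

    Raises:
        ValueError: If ``factor`` is negative.
    """
    if factor < 0:
        raise ValueError("factor must be non-negative")
    return [p * factor for p in probs]
```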
Files:
- monai/metrics/__init__.py
- monai/handlers/__init__.py
- monai/handlers/calibration.py
- tests/handlers/test_handler_calibration_error.py
- monai/metrics/calibration.py
- tests/metrics/test_calibration_metric.py
🧬 Code graph analysis (6)
monai/metrics/__init__.py (1)
- monai/metrics/calibration.py (3)
  - CalibrationErrorMetric (139-260)
  - CalibrationReduction (125-136)
  - calibration_binning (30-122)

monai/handlers/__init__.py (1)
- monai/handlers/calibration.py (1)
  - CalibrationError (23-71)

monai/handlers/calibration.py (1)
- monai/utils/enums.py (1)
  - MetricReduction (239-250)

tests/handlers/test_handler_calibration_error.py (4)
- monai/handlers/calibration.py (1)
  - CalibrationError (23-71)
- monai/handlers/utils.py (1)
  - from_engine (170-210)
- monai/utils/module.py (2)
  - min_version (273-285)
  - optional_import (315-445)
- tests/test_utils.py (1)
  - assert_allclose (119-159)

monai/metrics/calibration.py (4)
- monai/metrics/metric.py (1)
  - CumulativeIterationMetric (296-353)
- monai/metrics/utils.py (2)
  - do_metric_reduction (71-130)
  - ignore_background (54-68)
- monai/utils/enums.py (2)
  - MetricReduction (239-250)
  - StrEnum (68-90)
- monai/utils/profiling.py (1)
  - end (430-432)

tests/metrics/test_calibration_metric.py (3)
- monai/metrics/calibration.py (4)
  - CalibrationErrorMetric (139-260)
  - CalibrationReduction (125-136)
  - calibration_binning (30-122)
  - aggregate (239-260)
- monai/utils/enums.py (1)
  - MetricReduction (239-250)
- monai/metrics/metric.py (1)
  - get_buffer (282-293)
🪛 Ruff (0.14.11)
tests/handlers/test_handler_calibration_error.py
106-106: Unused function argument: engine
(ARG001)
142-142: Unused function argument: engine
(ARG001)
168-168: Unused function argument: engine
(ARG001)
monai/metrics/calibration.py
23-27: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
71-71: Avoid specifying long messages outside the exception class
(TRY003)
73-73: Avoid specifying long messages outside the exception class
(TRY003)
75-75: Avoid specifying long messages outside the exception class
(TRY003)
204-204: Unused method argument: kwargs
(ARG002)
237-237: Avoid specifying long messages outside the exception class
(TRY003)
256-256: Prefer TypeError exception for invalid type
(TRY004)
256-256: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (ubuntu-latest)
🔇 Additional comments (14)

monai/handlers/__init__.py (1)
- 15-15: LGTM! Import is correctly placed alphabetically and aligns with the new CalibrationError handler in monai/handlers/calibration.py.

monai/metrics/__init__.py (1)
- 16-16: LGTM! Public API exports correctly added for the new calibration functionality.

tests/metrics/test_calibration_metric.py (4)
- 14-25: LGTM! Imports and device setup are appropriate. Good use of _device for CUDA/CPU portability.
- 142-196: LGTM! Comprehensive binning tests with proper NaN handling and input validation coverage.
- 199-253: LGTM! All three calibration reduction modes tested with proper isolation from metric reduction.
- 256-354: LGTM! Good coverage of metric options including include_background, metric_reduction, get_not_nans, cumulative behavior, and reset.

tests/handlers/test_handler_calibration_error.py (3)
- 19-25: LGTM! Proper optional import pattern for Ignite with version check and skip decorator.
- 82-122: LGTM! Handler tests properly verify metric computation and details shape. The unused engine parameter in _val_func is required by Ignite's callback signature.
- 124-181: LGTM! Edge case tests cover single iteration and save_details=False behavior with appropriate defensive checks.

monai/handlers/calibration.py (1)
- 23-71: LGTM! Clean handler implementation following MONAI patterns. Docstring adequately documents all parameters. Consider adding a usage example similar to other handlers if desired.

monai/metrics/calibration.py (4)
- 30-122: calibration_binning looks solid. Validation, binning, and empty-bin NaN handling are clear and consistent with the stated contract.
- 125-136: Enum values are clear. Naming and values match expected calibration reduction modes.
- 187-203: Init wiring looks good. Config is stored cleanly and defaults are sensible.
- 239-260: Aggregate logic is clean. Reduction and get_not_nans behavior are consistent with MONAI patterns.
```python
if self.calibration_reduction == CalibrationReduction.EXPECTED:
    # Calculate the weighted sum of absolute differences
    return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1)
elif self.calibration_reduction == CalibrationReduction.AVERAGE:
    return torch.nanmean(abs_diff, dim=-1)  # Average across all dimensions, ignoring nan
elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
    abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0)
    return torch.max(abs_diff_no_nan, dim=-1).values  # Maximum across all dimensions
```
🧩 Analysis chain

🏁 Script executed:
```shell
find . -type f -name "calibration.py" | grep metrics
```

🏁 Script executed:
```shell
cat -n monai/metrics/calibration.py | head -250 | tail -50
```

🏁 Script executed:
```shell
rg "CalibrationReduction" --type py -B 2 -A 2
```

🏁 Script executed:
```shell
fd "test.*calibration" --type f
```

🏁 Script executed:
```shell
cat tests/metrics/test_calibration_metric.py
```

🏁 Script executed:
```shell
rg "calibration_binning" -B 5 -A 10 monai/metrics/calibration.py | head -80
```

🏁 Script executed:
```shell
python3 << 'EOF'
import torch

# Test torch.max behavior with all NaN tensor
all_nan = torch.tensor([[float('nan'), float('nan')], [float('nan'), float('nan')]])
print("All NaN tensor:", all_nan)
print("torch.max(all_nan, dim=-1):", torch.max(all_nan, dim=-1))

# Test nan_to_num with nan=0.0
nan_to_zero = torch.nan_to_num(all_nan, nan=0.0)
print("\nnan_to_num with nan=0.0:", nan_to_zero)
print("torch.max(nan_to_zero, dim=-1):", torch.max(nan_to_zero, dim=-1))

# Test nan_to_num with nan=-inf
nan_to_neginf = torch.nan_to_num(all_nan, nan=-float('inf'))
print("\nnan_to_num with nan=-inf:", nan_to_neginf)
print("torch.max(nan_to_neginf, dim=-1):", torch.max(nan_to_neginf, dim=-1))

# Test torch.nanmax
print("\ntorch.nanmax(all_nan, dim=-1):", torch.nanmax(all_nan, dim=-1))

# Test comparison for restoration
max_vals = torch.max(nan_to_neginf, dim=-1).values
print("\nmax_vals after max(nan_to_neginf):", max_vals)
print("max_vals == -inf:", max_vals == -float('inf'))
max_vals[max_vals == -float('inf')] = float('nan')
print("After restoring NaN:", max_vals)
EOF
```

🏁 Script executed:
```shell
python3 << 'EOF'
import torch

# Test 1: torch.max behavior with all NaN tensor
print("=== Test 1: All NaN tensor ===")
all_nan = torch.tensor([[float('nan'), float('nan')], [float('nan'), float('nan')]])
print("Input:", all_nan)
result = torch.max(all_nan, dim=-1)
print("torch.max(all_nan, dim=-1).values:", result.values)
print()

# Test 2: nan_to_num with nan=0.0
print("=== Test 2: nan_to_num with nan=0.0 ===")
nan_to_zero = torch.nan_to_num(all_nan, nan=0.0)
print("After nan_to_num(nan=0.0):", nan_to_zero)
print("torch.max result:", torch.max(nan_to_zero, dim=-1).values)
print()

# Test 3: nan_to_num with nan=-inf
print("=== Test 3: nan_to_num with nan=-inf ===")
nan_to_neginf = torch.nan_to_num(all_nan, nan=-float('inf'))
print("After nan_to_num(nan=-inf):", nan_to_neginf)
max_result = torch.max(nan_to_neginf, dim=-1).values
print("torch.max result:", max_result)
print()

# Test 4: Restore NaN after max
print("=== Test 4: Restore NaN ===")
max_vals = max_result.clone()
print("Original max_vals:", max_vals)
max_vals[max_vals == -float('inf')] = float('nan')
print("After restoring NaN:", max_vals)
print()

# Test 5: Check if torch.nanmax exists and works
print("=== Test 5: torch.nanmax ===")
try:
    result = torch.nanmax(all_nan, dim=-1)
    print("torch.nanmax(all_nan, dim=-1):", result)
except Exception as e:
    print("Error:", e)
EOF
```
🏁 Script executed:
```shell
rg "all.*nan|empty.*bin|nan.*maximum" tests/metrics/test_calibration_metric.py -i
```

🏁 Script executed:
```shell
sed -n '1,150p' monai/metrics/calibration.py | tail -80
```
Use sentinel value to preserve NaN when all bins are empty in MAXIMUM reduction
When all bins are empty (all NaN), nan_to_num(..., nan=0.0) converts NaN values to 0, causing MCE to return 0 instead of NaN. This misrepresents "no valid data" as "zero error". Replace with -inf as sentinel and restore NaN for all-NaN cases:
Suggested fix:

```diff
 elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
-    abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0)
-    return torch.max(abs_diff_no_nan, dim=-1).values  # Maximum across all dimensions
+    abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=-torch.inf)
+    max_vals = torch.max(abs_diff_no_nan, dim=-1).values
+    max_vals[max_vals == -torch.inf] = torch.nan
+    return max_vals  # Maximum across valid bins
```

Additionally, add a test case for the all-empty-bins edge case to prevent regression.
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if self.calibration_reduction == CalibrationReduction.EXPECTED:
    # Calculate the weighted sum of absolute differences
    return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1)
elif self.calibration_reduction == CalibrationReduction.AVERAGE:
    return torch.nanmean(abs_diff, dim=-1)  # Average across all dimensions, ignoring nan
elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
    abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=-torch.inf)
    max_vals = torch.max(abs_diff_no_nan, dim=-1).values
    max_vals[max_vals == -torch.inf] = torch.nan
    return max_vals  # Maximum across valid bins
```
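A minimal sketch of the regression test the review asks for. It exercises the sentinel logic at the tensor level rather than through the full metric, since constructing an all-empty-bins input via calibration_binning depends on internals not shown here; the test name and shapes are illustrative:

```python
import torch


def test_maximum_reduction_with_all_empty_bins():
    # Per-bin |mean prediction - mean ground truth| where every bin was empty: all NaN.
    abs_diff = torch.full((2, 3, 5), float("nan"))  # (batch, class, bins)

    # Sentinel-based maximum, mirroring the suggested fix above.
    no_nan = torch.nan_to_num(abs_diff, nan=float("-inf"))
    max_vals = torch.max(no_nan, dim=-1).values
    max_vals[max_vals == float("-inf")] = float("nan")

    # "No valid data" must surface as NaN, not as zero error.
    assert torch.isnan(max_vals).all()
```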
Description
Addresses #8505
Overview
This PR adds calibration error metrics and an Ignite handler to MONAI, enabling users to evaluate and monitor model calibration for segmentation and other multi-class probabilistic tasks with shape (B, C, spatial...).

What's Included
1. Calibration Metrics (monai/metrics/calibration.py)
- calibration_binning(): Core function to compute calibration bins with mean predictions, mean ground truths, and bin counts. Exported to support research workflows where users need per-bin statistics for plotting reliability diagrams.
- CalibrationReduction: Enum supporting three reduction methods (summarized mathematically below):
  - EXPECTED - Expected Calibration Error (ECE): weighted average by bin count
  - AVERAGE - Average Calibration Error (ACE): simple average across bins
  - MAXIMUM - Maximum Calibration Error (MCE): maximum error across bins
- CalibrationErrorMetric: A CumulativeIterationMetric subclass supporting background inclusion/exclusion (include_background) and the standard metric reductions (mean, sum, mean_batch, etc.)
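For reference (notation ours, not taken verbatim from the PR): with per-bin mean prediction $p_b$, per-bin mean ground truth $g_b$, bin count $n_b$, total count $N=\sum_b n_b$, and $B^{+}$ the set of non-empty bins, the three reductions correspond to

$$
\mathrm{ECE}=\sum_{b\in B^{+}}\frac{n_b}{N}\,\lvert p_b-g_b\rvert,\qquad
\mathrm{ACE}=\frac{1}{\lvert B^{+}\rvert}\sum_{b\in B^{+}}\lvert p_b-g_b\rvert,\qquad
\mathrm{MCE}=\max_{b\in B^{+}}\lvert p_b-g_b\rvert,
$$

where empty bins carry NaN internally and are skipped, matching the nansum/nanmean handling in the code above.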
2. Ignite Handler (monai/handlers/calibration.py)
- CalibrationError: An IgniteMetricHandler wrapper that supports save_details for per-sample/per-channel metric details via the metric buffer.

3. Comprehensive Tests
- tests/metrics/test_calibration_metric.py: Tests covering the binning function, the calibration reduction modes, and the metric options.
- tests/handlers/test_handler_calibration_error.py: Tests covering handler execution via engine.run() and the save_details functionality.

Public API
Exposes the following via monai.metrics:
- CalibrationErrorMetric
- CalibrationReduction
- calibration_binning

Exposes the following via monai.handlers:
- CalibrationError

Implementation Notes
- Uses scatter_add plus explicit counts instead of scatter_reduce("mean") for better PyTorch version compatibility (see the sketch below)
- Uses torch.nan_to_num instead of in-place operations for cleaner code
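A minimal sketch of the scatter_add-plus-counts pattern the note refers to, written as a standalone illustration rather than the PR's exact code (function name and shapes are ours):

```python
import torch


def per_bin_means(probs: torch.Tensor, labels: torch.Tensor, num_bins: int = 10):
    """Mean prediction, mean label, and count per probability bin (1D inputs)."""
    # Assign each probability to a bin index in [0, num_bins - 1].
    bin_idx = torch.clamp((probs * num_bins).long(), max=num_bins - 1)

    # scatter_add_ accumulates sums per bin; dividing by counts yields means,
    # which works on older PyTorch versions lacking scatter_reduce("mean").
    counts = torch.zeros(num_bins).scatter_add_(0, bin_idx, torch.ones_like(probs))
    sum_probs = torch.zeros(num_bins).scatter_add_(0, bin_idx, probs)
    sum_labels = torch.zeros(num_bins).scatter_add_(0, bin_idx, labels.float())

    # Empty bins become NaN via 0/0, matching the PR's empty-bin convention.
    return sum_probs / counts, sum_labels / counts, counts
```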
Related Work

The algorithmic approach follows the calibration metrics from Average-Calibration-Losses and its related publications.
Future Work
As discussed in the issue, calibration losses will be added in a separate PR to keep changes focused and easier to review.
Checklist
- New definitions exported through the relevant __init__.py files

Example Usage
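A sketch of direct metric usage. Argument names follow the PR description and review snippets; num_bins in particular is an assumed parameter name, and exact signatures may differ:

```python
import torch

from monai.metrics import CalibrationErrorMetric, CalibrationReduction

metric = CalibrationErrorMetric(
    num_bins=20,  # assumed parameter name for the number of calibration bins
    calibration_reduction=CalibrationReduction.EXPECTED,  # ECE
    include_background=True,
)

y_pred = torch.rand(4, 2, 32, 32)  # (B, C, spatial...) predicted probabilities
y = (torch.rand(4, 2, 32, 32) > 0.5).float()  # binary ground truth, same shape

metric(y_pred=y_pred, y=y)  # accumulate one iteration
print(metric.aggregate())   # reduced calibration error over accumulated results
metric.reset()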
With Ignite Handler
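A hedged sketch patterned on other MONAI Ignite handlers; the constructor arguments shown are assumptions based on the PR description, not confirmed signatures:

```python
from ignite.engine import Engine

from monai.handlers import CalibrationError, from_engine


def _val_step(engine, batch):
    # A real validation step would run the network; here we pass data through.
    return {"pred": batch["pred"], "label": batch["label"]}


evaluator = Engine(_val_step)

handler = CalibrationError(
    output_transform=from_engine(["pred", "label"]),
    save_details=True,  # keep per-sample/per-channel details in the metric buffer
)
handler.attach(evaluator, name="calibration_error")

# state = evaluator.run(val_data_loader)
# print(state.metrics["calibration_error"])
```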