Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss #8669
base: dev
Conversation
Walkthrough: Introduces … (summary truncated in source)
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ Failed checks (1 warning), ✅ Passed checks (4 passed)
Actionable comments posted: 3
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)
76-82: Add stacklevel=2 to warning. Per static analysis and Python conventions, set stacklevel to point to the caller.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
177-194: Logic is correct; consider documenting the return value. The forward method correctly passes logits to FocalLoss and probabilities to AsymmetricFocalTverskyLoss. Per coding guidelines, docstrings should document return values.
  """
  Args:
      y_pred: (BNH[WD]) Logits (raw scores).
      y_true: (BNH[WD]) Ground truth labels.
+
+ Returns:
+     torch.Tensor: Weighted combination of focal loss and asymmetric focal Tversky loss.
  """
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/unified_focal_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/unified_focal_loss.py
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (3)
monai/losses/focal_loss.py (1)
  FocalLoss (26-202)
monai/networks/utils.py (1)
  one_hot (170-220)
monai/utils/enums.py (1)
  LossReduction (253-264)
🪛 Ruff (0.14.8)
monai/losses/unified_focal_loss.py
78-78: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
127-127: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (1)
monai/losses/unified_focal_loss.py (1)
36-44: Tests already cover the unified focal loss implementation. New tests were added to cover the changes. The PR indicates that test coverage has been implemented, so this concern can be closed.
Actionable comments posted: 1
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
157-157: Replace the Chinese comment (儲存參數, "store parameter") with English, or simply drop it.
- self.use_softmax = use_softmax  # 儲存參數
+ self.use_softmax = use_softmax
112-114: Numerical instability when dice approaches 1.0. When dice_class[:, i] equals 1.0, torch.pow(0, -self.gamma) produces infinity, causing NaN gradients.
Proposed fix:
- # Foreground classes: apply focal modulation
- # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
- loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
+ # Foreground classes: apply focal modulation
+ back_dice = torch.clamp(1 - dice_class[:, i], min=self.epsilon)
+ loss_list.append(back_dice * torch.pow(back_dice, -self.gamma))
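For readers following along, a standalone reproduction of the instability described above (plain PyTorch, independent of the PR code; the gamma and epsilon values are illustrative assumptions):

```python
import torch

gamma, epsilon = 0.75, 1e-7
dice = torch.tensor([1.0, 0.8])  # first class has a perfect Dice score

# naive form: (1 - dice) * (1 - dice)^(-gamma); pow(0, -gamma) is inf, and 0 * inf is nan
print((1 - dice) * torch.pow(1 - dice, -gamma))   # tensor([   nan, 0.6687])

# clamping the base keeps the result finite for the perfect class
safe = torch.clamp(1 - dice, min=epsilon)
print(safe * torch.pow(safe, -gamma))             # tensor([0.0178, 0.6687])
```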
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
74-80: Background exclusion logic is correct. The implementation properly handles the single-channel edge case.
Optional: add stacklevel to the warning
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
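As background on the recurring stacklevel suggestion, a small self-contained illustration of the standard-library behaviour (not PR code; the validate_channels helper below is hypothetical): with the default stacklevel=1 the warning is attributed to the warnings.warn line inside the helper, while stacklevel=2 attributes it to the caller.

```python
import warnings

def validate_channels(include_background: bool, n_channels: int) -> None:  # hypothetical helper
    if n_channels == 1 and not include_background:
        # stacklevel=2 makes the reported file/line point at the code calling this helper,
        # not at this warnings.warn(...) statement itself.
        warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)

validate_channels(include_background=False, n_channels=1)
```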
🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)
34-58: Constructor changes look good. The include_background parameter addition aligns with MONAI conventions and is properly documented.
159-173: Component initialization is appropriate. Both loss components are properly configured with consistent parameters.
181-188: Remove redundant to_onehot_y conversion concern. Both focal_loss and asy_focal_tversky_loss independently apply to_onehot_y conversion when enabled. Since PyTorch's one-hot encoding creates new tensors (not in-place modifications), passing the same y_true reference to both components is safe and functionally correct. No conversion conflicts occur.
Actionable comments posted: 0
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (2)
76-76: Add stacklevel=2 to warning. Missing explicit stacklevel for proper traceback.
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
Based on static analysis hints.
176-193: Document that y_pred must be logits. The forward method expects raw logits since FocalLoss applies activation internally, and manual softmax/sigmoid is applied for the Tversky component. The docstring states this (line 179) but could be more explicit about the consequences of passing probabilities instead. Consider adding a note in the docstring:
  def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
      """
      Args:
-         y_pred: (BNH[WD]) Logits (raw scores).
+         y_pred: (BNH[WD]) Logits (raw scores, not probabilities).
+             Do not pass pre-activated inputs; activation is applied internally.
          y_true: (BNH[WD]) Ground truth labels.
      """
tests/losses/test_unified_focal_loss.py (1)
26-61: Add test coverage for edge cases. Current tests only cover perfect predictions with zero loss. Missing coverage for:
- Imperfect predictions (non-zero loss)
- include_background=False scenarios
- to_onehot_y=True with integer labels
- Multi-class softmax with imperfect predictions
Suggested additional test cases:
    # Case 2: Binary with include_background=False
    [
        {"use_softmax": False, "include_background": False},
        {
            "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]),
            "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
        },
        0.0,  # Should still be zero for perfect prediction
    ],
    # Case 3: Multi-class with to_onehot_y=True (integer labels)
    [
        {"use_softmax": True, "include_background": True, "to_onehot_y": True},
        {
            "y_pred": torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_neg], [logit_neg, logit_pos]]]]),
            "y_true": torch.tensor([[[[0, 2]]]]),  # Integer labels: class 0, class 2
        },
        0.0,
    ],
Do you want me to generate a complete test case addition or open an issue to track this?
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (3)
184-187: LGTM: Correct activation choice for Tversky loss. FocalLoss handles its own activation internally, so this manual conversion to probabilities for AsymmetricFocalTverskyLoss is correct. The activation choice (softmax vs sigmoid) properly follows the use_softmax flag.
89-116: Implementation correctly handles include_background with the standard MONAI slicing pattern. When include_background=False, channel index 0 is excluded from the calculation: the code does this via tensor slicing at lines 79-80 before the asymmetry loop. Once sliced, all remaining channels receive focal modulation; none are treated as background. The loss only supports binary segmentation, so asymmetry designates the first present channel as background and all others as foreground, which is the intended behavior per the documented design comment (lines 101-104).
160-174: Both composed losses independently transform y_true with their respective settings. Each applies its own non-destructive transformations (one-hot encoding creates new tensors; slicing creates new views), so no actual collision occurs. This is correct by design; composed losses should handle their own input transformations.
tests/losses/test_unified_focal_loss.py (2)
22-24: LGTM: High-confidence logits ensure clear test expectations. Using ±10.0 logits produces near-perfect probabilities (~0.9999 and ~0.0001), making zero-loss expectations reasonable for perfect predictions.
77-89: LGTM: CUDA test correctly instantiates loss. The test properly moves the loss module to CUDA (line 85), ensuring both model parameters and inputs are on the same device.
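To make the ±10.0 logit convention concrete, a quick plain-PyTorch check (independent of the test file) of the probabilities those logits yield:

```python
import torch

logits = torch.tensor([10.0, -10.0])
print(torch.sigmoid(logits))         # ~[0.99995, 0.00005]: near-perfect per-channel probabilities
print(torch.softmax(logits, dim=0))  # ~[1.0, 2e-09]: effectively one-hot across channels
```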
Force-pushed from 39664ea to 41dccad
Actionable comments posted: 1
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
74-80: Background exclusion correctly implemented. The logic properly removes the first channel when include_background=False, consistent with FocalLoss. The single-channel warning is appropriate.
Optional: add stacklevel to warning for better traceability
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
🔇 Additional comments (8)
tests/losses/test_unified_focal_loss.py (3)
22-42: Binary test case is correct. The high-confidence logits (±10.0) correctly produce near-perfect probabilities after sigmoid. The alignment between predictions and targets should yield near-zero loss.
66-70: Test structure is correct. Parameterized test properly unpacks configuration and data, with appropriate numerical tolerances for floating-point comparison.
77-89: CUDA test correctly adapted to logits interface. The test properly uses logits with use_softmax=False for binary segmentation and correctly moves both tensors and the loss module to CUDA.
monai/losses/unified_focal_loss.py (5)
19-19: Import is correct. FocalLoss is properly imported from monai.losses for reuse in the unified loss.
34-58: Constructor properly extended with include_background. The parameter is correctly documented, defaulted, and stored for use in the forward method, consistent with MONAI's loss interface patterns.
97-117: Asymmetric focal modulation correctly implemented. Background class uses standard Dice loss while foreground classes apply focal modulation (1-dice)^(1-gamma). Clamping prevents numerical instability when dice approaches 1.0.
135-174: Composition pattern correctly implemented. The constructor properly instantiates and configures both FocalLoss and AsymmetricFocalTverskyLoss components with shared parameters, enabling modular loss computation.
176-193: Forward method correctly combines losses. FocalLoss operates on logits (with internal activation), while AsymmetricFocalTverskyLoss requires probabilities. The explicit softmax/sigmoid conversion for the Tversky component is correct, and the weighted combination is straightforward.
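As a reference for the data flow described here, a minimal sketch of the combination (a simplification that assumes the two composed losses and a weight factor; not the PR's verbatim code):

```python
import torch

def unified_focal_forward(y_pred_logits, y_true, focal_loss, tversky_loss,
                          use_softmax: bool = False, weight: float = 0.5) -> torch.Tensor:
    focal = focal_loss(y_pred_logits, y_true)  # FocalLoss applies its own activation to logits
    prob = torch.softmax(y_pred_logits, dim=1) if use_softmax else torch.sigmoid(y_pred_logits)
    tversky = tversky_loss(prob, y_true)       # AsymmetricFocalTverskyLoss expects probabilities
    return weight * focal + (1 - weight) * tversky
```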
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)
26-56: Add test case for include_background=False. Test cases cover sigmoid/softmax modes correctly, but the include_background parameter (added per PR objectives) is only tested with True. Add a multi-class case with include_background=False to validate background channel exclusion.
Example test case:
    # Case 2: Multi-class with background excluded
    [
        {"use_softmax": True, "include_background": False},
        {
            "y_pred": torch.tensor([[[[logit_pos, logit_neg]], [[logit_neg, logit_pos]], [[logit_neg, logit_neg]]]]),
            "y_true": torch.tensor([[[[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 0.0]]]]),
        },
        0.0,
    ],
🔇 Additional comments (3)
tests/losses/test_unified_focal_loss.py (3)
22-24: LGTM - Clear test constants. Helper logits are well-defined for creating high-confidence predictions.
62-65: LGTM - Parameterized test structure correct. Test method properly unpacks config and data dicts.
72-84: LGTM - CUDA test properly implemented. Test correctly uses logits and moves both tensors and loss module to CUDA.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)
79-82: Add docstring. Per coding guidelines, add a docstring describing that this test validates shape mismatch error handling.
🧹 Nitpick comments (1)
tests/losses/test_unified_focal_loss.py (1)
26-68: Add at least one test with non-zero loss. All test cases expect 0.0 loss with perfect predictions. Add a case with imperfect predictions (e.g., logits slightly off from ground truth) to verify the loss is actually computed, not just validating tensor shape compatibility.
Optional: Expand parameter coverage
Consider adding test cases that vary:
- to_onehot_y=True with class-index format ground truth
- weight, delta, gamma set to non-default values
- reduction modes (SUM, NONE)
These are optional and can be deferred.
🔇 Additional comments (1)
tests/losses/test_unified_focal_loss.py (1)
22-24: LGTM: High-confidence logits for perfect-prediction tests. Values create predictions very close to 0 or 1, suitable for validating near-zero loss on ideal inputs.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/losses/test_unified_focal_loss.py (1)
26-68: Test coverage is adequate for the new interface. The three cases properly exercise binary sigmoid, multi-class softmax, and background exclusion. All use perfect predictions (loss=0.0), which validates the happy path.
Consider adding at least one test case with imperfect predictions (non-zero expected loss) to verify the actual loss computation, not just the zero-loss edge case. This would increase confidence in the refactored implementation.
monai/losses/unified_focal_loss.py (1)
134-173: Excellent refactoring using composition. Replacing inline logic with composed FocalLoss and AsymmetricFocalTverskyLoss improves maintainability and reusability. Parameters are correctly forwarded to both components.
Per past review feedback, the use_softmax docstring (lines 151-152) could add brief guidance: softmax for mutually exclusive classes (standard multi-class), sigmoid for multi-label/overlapping classes.
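A short usage sketch of that guidance, assuming the interface added in this PR (use_softmax is new here, so this will not work against released MONAI; the constructor arguments shown are assumptions):

```python
from monai.losses import AsymmetricUnifiedFocalLoss

# Mutually exclusive classes (standard multi-class segmentation): softmax over channels.
multiclass_loss = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)

# Binary or multi-label (overlapping) targets: per-channel sigmoid.
binary_loss = AsymmetricUnifiedFocalLoss(use_softmax=False)
```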
🔇 Additional comments (9)
tests/losses/test_unified_focal_loss.py (3)
22-24: LGTM. Helper constants are well-commented and appropriate for generating high-confidence predictions in tests.
74-85: LGTM. Docstring properly documents parameters. Parameterized test structure is clean and flexible.
92-104: LGTM. CUDA test correctly uses the new API with sigmoid activation and validates GPU compatibility.
monai/losses/unified_focal_loss.py (6)
19-19: LGTM. Import required for the new composition-based implementation.
34-58: LGTM. Adding the include_background parameter aligns with MONAI loss function conventions and enables proper multi-class segmentation support.
74-80: LGTM. Background exclusion logic correctly follows the FocalLoss pattern, including the single-channel warning.
106-114: Asymmetry logic is correct. Background channel (index 0 when include_background=True) uses standard Dice loss, while foreground channels use focal modulation. When include_background=False, all channels receive focal modulation since background was removed. Clamping prevents numerical instability.
118-125: LGTM. Reduction logic correctly handles MEAN, SUM, and NONE cases with an appropriate error for unsupported values.
175-195: LGTM. Forward pass correctly handles different input requirements: logits for FocalLoss (which applies activation internally), probabilities for AsymmetricFocalTverskyLoss. The weighted combination is straightforward and matches the documented formula.
Force-pushed from 2f4657e to e63e36e
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/losses/test_unified_focal_loss.py (1)
117-128: Add docstring for CUDA test. Per coding guidelines and a past review comment, add a docstring describing this test's purpose: validating CUDA compatibility with perfect predictions.
Suggested docstring:
  def test_with_cuda(self):
+     """Test AsymmetricUnifiedFocalLoss CUDA compatibility with perfect predictions."""
      loss = AsymmetricUnifiedFocalLoss()
🧹 Nitpick comments (3)
tests/losses/test_unified_focal_loss.py (1)
25-93: Suggest adding imperfect prediction test cases. All three cases test perfect predictions (loss=0.0). Add at least one case with misaligned logits/labels to verify the loss computes non-zero values correctly and gradients flow properly.
Example imperfect case:
    [  # Case 3: Imperfect prediction
        {"use_softmax": False, "include_background": True},
        {
            "y_pred": torch.tensor([[[[0.0, -2.0], [2.0, 0.0]]]]),  # Moderate confidence
            "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
        },
        # Expected: positive loss value (compute actual expected value)
    ],
monai/losses/unified_focal_loss.py (2)
60-86: LGTM: Background exclusion logic is correct. The include_background handling properly slices channel 0 from both tensors and warns on single-channel edge cases. Shape validation and clipping are correctly placed.
Note: Static analysis flags line 83 for a long exception message (TRY003). Consider a custom exception class if this pattern recurs, but current usage is acceptable.
88-125: LGTM: Asymmetric focal Tversky logic is sound. The per-class loss correctly applies standard Tversky to background (when included) and focal-modulated Tversky to foreground. Clamping prevents numerical instability. Reduction handling is complete.
Static analysis flags line 125 for a long exception message (TRY003). Consider extracting to a constant or custom exception if this pattern is reused.
🔇 Additional comments (6)
tests/losses/test_unified_focal_loss.py (2)
22-23: LGTM: Clear test constants. Module-level logit constants are well-named and appropriate for testing high-confidence predictions.
99-110: LGTM: Well-documented parameterized test. Docstring and implementation are clear. Tolerance levels are appropriate.
monai/losses/unified_focal_loss.py (4)
19-19: LGTM: Import supports composition pattern. The FocalLoss import enables the refactored AsymmetricUnifiedFocalLoss to reuse the existing focal loss implementation.
34-58: LGTM: Consistent API with include_background parameter. Addition of include_background aligns with MONAI loss conventions. The default of True preserves backward compatibility.
128-175: LGTM: Clean composition of focal components. Refactoring to compose FocalLoss and AsymmetricFocalTverskyLoss eliminates code duplication and ensures consistent parameter handling. Docstrings clearly distinguish sigmoid vs. softmax use cases.
177-197: LGTM: Forward pass correctly combines loss components. The focal loss operates on logits while the Tversky component operates on probabilities after explicit activation. Each component independently handles to_onehot_y and include_background, ensuring correct behavior. Weighted combination is straightforward.
Force-pushed from e63e36e to ccc5459
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
157-157: Remove Chinese comment.
- self.use_softmax = use_softmax  # 儲存參數
+ self.use_softmax = use_softmax
This was previously flagged but remains in the code.
🧹 Nitpick comments (1)
monai/losses/unified_focal_loss.py (1)
118-125: Reduction logic is correct. Standard reduction pattern implemented properly.
For consistency with MONAI style, consider extracting the long error message to a constant or shortening it (static analysis hint TRY003):
- raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+ raise ValueError(f"Unsupported reduction: {self.reduction}")
🔇 Additional comments (5)
monai/losses/unified_focal_loss.py (5)
19-19: LGTM. Import required for the new composition-based implementation.
34-59: LGTM. The include_background parameter is properly documented and maintains backward compatibility with True as the default.
74-81: Background exclusion logic is correct. The warning for single-channel predictions and the slicing logic are appropriate.
However, add stacklevel=2 to the warning at line 76 for proper caller identification:
- warnings.warn("single channel prediction, `include_background=False` ignored.")
+ warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
Based on static analysis hints.
159-173: Asymmetric gamma application is intentional and correct. The Unified Focal Loss design intentionally exploits gamma asymmetry to enable simultaneous suppression and enhancement effects in its component losses. In FocalLoss, gamma down-weights easy-to-classify pixels, while in Focal Tversky Loss, gamma enhances rather than suppresses easy examples. Gamma controls weights for difficult-to-predict samples; distribution-based corrections apply sample-by-sample while region-based corrections apply class-by-class during macro-averaging. This composition pattern correctly implements the unified focal loss framework.
175-192: Forward implementation is correct. The loss properly:
- Computes focal loss on logits
- Converts logits to probabilities for Tversky component via softmax or sigmoid
- Combines losses with configurable weighting
Test coverage includes both sigmoid and softmax activation paths with appropriate input dimensions.
Force-pushed from 1d196dc to edb01ce
Actionable comments posted: 3
♻️ Duplicate comments (1)
tests/losses/test_unified_focal_loss.py (1)
93-106: Add docstring. Per coding guidelines, add a docstring describing the test purpose.
🔎 Proposed fix
  def test_with_cuda(self):
+     """Validate CUDA compatibility of AsymmetricUnifiedFocalLoss."""
      if not torch.cuda.is_available():
Based on coding guidelines.
🔇 Additional comments (7)
tests/losses/test_unified_focal_loss.py (2)
21-66: Test case definitions are well-structured. The test cases cover binary logits, 2-channel binary, and multi-class scenarios with appropriate shapes and parameter combinations. The use of 10.0/-10.0 logits ensures near-perfect probabilities for validation.
71-80: LGTM. Tolerance of 1e-3 is appropriate given that logits of ±10.0 don't yield exact probabilities of 0.0/1.0.
monai/losses/unified_focal_loss.py (5)
34-60: LGTM. The use_softmax parameter is properly integrated with clear documentation.
91-129: Loss calculations are correct. The background dice and foreground focal-tversky computations align with the paper's formulation. The use of the 1/gamma exponent for foreground classes properly implements the focal modulation.
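For orientation, a small standalone sketch of the asymmetric per-class terms described here (tensor names are hypothetical; it mirrors the snippet quoted later in this thread rather than the exact source):

```python
import torch

def asymmetric_tversky_terms(dice_class: torch.Tensor, gamma: float = 0.75) -> torch.Tensor:
    # dice_class: (batch, n_classes) per-class Tversky indices, channel 0 treated as background
    back = (1 - dice_class[:, 0]).unsqueeze(1)          # background: plain Tversky/Dice term
    fore = torch.pow(1 - dice_class[:, 1:], 1 / gamma)  # foreground: focal modulation via the 1/gamma exponent
    return torch.cat([back, fore], dim=1)
```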
196-226: Focal loss implementation is correct. The asymmetric weighting (background focal, foreground standard CE) with delta balancing correctly addresses class imbalance.
237-279: Composition pattern is well-executed. Creating internal loss instances with shared parameters ensures consistency and avoids duplication in the forward pass.
281-298: Forward logic is sound. The shape validation correctly handles edge cases (binary logits, to_onehot_y), and the weighted combination properly unifies focal and tversky losses.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
240-240: num_classes parameter is unused. The num_classes parameter is stored at line 260 but never referenced. Either remove it or use it.
🔎 Proposed fix to remove unused parameter
  def __init__(
      self,
      to_onehot_y: bool = False,
-     num_classes: int = 2,
      weight: float = 0.5,
      gamma: float = 0.5,
      delta: float = 0.7,
      reduction: LossReduction | str = LossReduction.MEAN,
      use_softmax: bool = False,
  ):
      """
      Args:
          to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-         num_classes: number of classes. Defaults to 2.
          weight: weight factor to balance between Focal Loss and Tversky Loss.
And remove self.num_classes = num_classes at line 260.
Also applies to: 260-260
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
83-83: Add stacklevel=2 to warning. Per static analysis, add stacklevel=2 so the warning points to the caller.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
188-188: Add stacklevel=2 to warning. Per static analysis, add stacklevel=2.
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
🧹 Nitpick comments (6)
monai/losses/unified_focal_loss.py (6)
122-129: Unreachable fallback return. The final return torch.mean(all_losses) at line 129 is unreachable for valid LossReduction values. Consider raising an error for invalid reductions or removing the redundant return.
🔎 Proposed fix
  if self.reduction == LossReduction.MEAN.value:
      return torch.mean(all_losses)
  if self.reduction == LossReduction.SUM.value:
      return torch.sum(all_losses)
  if self.reduction == LossReduction.NONE.value:
      return all_losses
-
- return torch.mean(all_losses)
+ raise ValueError(f"Unsupported reduction: {self.reduction}")
158-159: Incomplete docstring for the reduction parameter. The reduction parameter docstring is missing its description.
- reduction: {``"none"``, ``"mean"``, ``"sum"``}
- use_softmax: whether to use softmax to transform logits. Defaults to False.
+ reduction: {``"none"``, ``"mean"``, ``"sum"``}
+     Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+ use_softmax: whether to use softmax to transform logits. Defaults to False.
175-202: Consider extracting shared preprocessing logic. Lines 175-202 duplicate the preprocessing from AsymmetricFocalTverskyLoss (lines 70-99): single-channel handling, one-hot conversion, shape validation, and probability conversion. Extract to a shared helper to reduce duplication.
293-294: Sub-losses preprocess inputs independently, causing duplicate work. Both self.asy_focal_loss and self.asy_focal_tversky_loss independently apply sigmoid/softmax, one-hot encoding, and clamping to the same inputs. For performance, consider preprocessing once in this forward method and passing processed tensors to sub-losses configured to skip preprocessing.
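A rough sketch of what that could look like (hypothetical; it assumes the sub-losses can accept pre-activated probabilities, which is not how the current PR is wired):

```python
import torch

def forward_with_shared_preprocessing(loss_module, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # Activate once, then hand the same probabilities to both components.
    prob = torch.softmax(y_pred, dim=1) if loss_module.use_softmax else torch.sigmoid(y_pred)
    focal = loss_module.asy_focal_loss(prob, y_true)
    tversky = loss_module.asy_focal_tversky_loss(prob, y_true)
    return loss_module.weight * focal + (1 - loss_module.weight) * tversky
```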
62-68: Docstring should document return value and exceptions. Per coding guidelines, docstrings should describe the return value and raised exceptions.
  """
  Args:
      y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
      y_true: ground truth labels. Shape should match y_pred.
+
+ Returns:
+     torch.Tensor: Computed loss. Shape depends on reduction setting.
+
+ Raises:
+     ValueError: If y_true and y_pred shapes don't match after preprocessing.
  """
168-173: Docstring should document return value and exceptions. Same as AsymmetricFocalTverskyLoss.forward: add Returns and Raises sections.
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)
288-291: Shape validation is correct but duplicated. The shape check here is a reasonable early guard, though sub-losses will validate again after their own preprocessing. Acceptable as-is for fail-fast behavior.
296-298: LGTM! Clean delegation to sub-losses with proper weighted combination.
Signed-off-by: ytl0623 <[email protected]>
Force-pushed from 4166faa to 45d9877
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
237-279: Remove the unused num_classes parameter or document it for backward compatibility. The num_classes parameter is stored at line 260 but never used. It's not passed to internal loss instances and not referenced in the forward method. Either remove it or add a comment explaining why it's retained.
♻️ Duplicate comments (4)
monai/losses/unified_focal_loss.py (2)
186-191: Add stacklevel=2 to warning. The warnings.warn call should specify stacklevel=2 so the warning points to the caller's code.
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
81-86: Add stacklevel=2 to warning. The warnings.warn call should specify stacklevel=2 so the warning points to the caller's code.
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
tests/losses/test_unified_focal_loss.py (2)
94-107: Add docstring. Per coding guidelines, add a docstring describing that this test validates CUDA compatibility of the loss.
🔎 Proposed fix
  def test_with_cuda(self):
+     """Verify CUDA compatibility by running loss on GPU tensors when available."""
      if not torch.cuda.is_available():
83-87: Remove unused parameter and add docstring. The expected_val parameter is unused. Remove it from the signature and update the test case accordingly. Also add a docstring per coding guidelines.
🔎 Proposed fix
  @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
- def test_wrong_prediction(self, input_data, expected_val, args):
+ def test_wrong_prediction(self, input_data, args):
+     """Verify that wrong predictions yield high loss values."""
      loss_func = AsymmetricUnifiedFocalLoss(**args)
      result = loss_func(**input_data)
      self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")
Update TEST_CASE_MULTICLASS_WRONG at line 62 to remove the None value:
  TEST_CASE_MULTICLASS_WRONG = [
      {
          "y_pred": torch.tensor(
              [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]]
          ),
          "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),  # GT is class 0, but Pred is class 1
      },
-     None,
      {"use_softmax": True, "to_onehot_y": True},
  ]
Based on static analysis hint.
Actionable comments posted: 1
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
81-86: Add stacklevel=2 to warning (still open from past reviews).
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
186-191: Add stacklevel=2 to warning (still open from past reviews).
🔎 Proposed fix
  if self.to_onehot_y:
      if n_pred_ch == 1:
-         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+         warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
      else:
Based on static analysis hint.
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)
110-129: Loss computation is correct but consider documenting the asymmetric treatment. Background uses standard Dice loss while foreground uses focal modulation. This asymmetry is intentional to prioritize foreground classes, but could benefit from an inline comment for future maintainers.
🔎 Optional documentation enhancement
  # Calculate losses separately for each class
- # Background: Standard Dice Loss
+ # Background: Standard Dice Loss (no focal modulation to preserve sensitivity)
  back_dice = 1 - dice_class[:, 0]
- # Foreground: Focal Tversky Loss
+ # Foreground: Focal Tversky Loss (focal modulation to down-weight easy examples)
  fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
175-182: Optional: Extract duplicated single-channel handling into a helper function. The same single-channel auto-conversion logic appears in both AsymmetricFocalTverskyLoss (lines 69-78) and AsymmetricFocalLoss. Consider extracting to a shared helper if more losses adopt this pattern, as sketched below.
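A possible shape for that helper, sketched under the assumption that both losses follow the single-channel handling described in this thread (function name and signature are hypothetical):

```python
import torch
from monai.networks.utils import one_hot

def prepare_binary_or_onehot(y_pred: torch.Tensor, y_true: torch.Tensor, to_onehot_y: bool):
    """Shared preprocessing: expand single-channel inputs, otherwise optionally one-hot encode y_true."""
    if y_pred.shape[1] == 1:
        prob = torch.sigmoid(y_pred)
        y_pred = torch.cat([1 - prob, prob], dim=1)   # channel 0 = background, channel 1 = foreground
        y_true = torch.cat([1 - y_true, y_true], dim=1)
    elif to_onehot_y:
        y_true = one_hot(y_true, num_classes=y_pred.shape[1])
    return y_pred, y_true
```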
266-279: Consider exposing separate gamma parameters for the two loss components. AsymmetricFocalLoss defaults to gamma=2.0 while AsymmetricFocalTverskyLoss defaults to gamma=0.75, but AsymmetricUnifiedFocalLoss forces both to use the same gamma value. This prevents users from independently tuning focal modulation for the distribution-based (CE) vs region-based (Dice) objectives.
Not blocking, but consider adding gamma_focal and gamma_tversky parameters in a future revision if users request finer control.
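If that were pursued, one minimal way to expose it (entirely hypothetical: gamma_focal/gamma_tversky do not exist in this PR, and the attribute assignments assume the composed losses store their gamma as self.gamma):

```python
from monai.losses import AsymmetricUnifiedFocalLoss

class DualGammaUnifiedFocalLoss(AsymmetricUnifiedFocalLoss):  # hypothetical extension, not part of the PR
    def __init__(self, gamma_focal: float = 2.0, gamma_tversky: float = 0.75, **kwargs):
        super().__init__(**kwargs)
        self.asy_focal_loss.gamma = gamma_focal            # tune the distribution-based (CE) term
        self.asy_focal_tversky_loss.gamma = gamma_tversky  # tune the region-based (Dice) term
```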
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
LossReduction (253-264)
monai/networks/utils.py (1)
one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py
83-83: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
89-89: Avoid specifying long messages outside the exception class
(TRY003)
188-188: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
monai/losses/unified_focal_loss.py (3)
69-78: Verify single-channel sigmoid conversion produces a valid probability distribution. The auto-handling converts single-channel logits to two-channel probabilities via torch.cat([1 - y_pred, y_pred], dim=1) after sigmoid. This assumes y_pred after sigmoid is the foreground probability. Confirm this matches user expectations and aligns with the rest of MONAI's binary segmentation conventions; a quick standalone check is sketched below.
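A quick standalone check of the convention (not part of the PR, just an illustration that the two channels form a valid per-pixel distribution):

import torch

logits = torch.randn(2, 1, 4, 4)       # single-channel binary logits
p = torch.sigmoid(logits)               # foreground probability
two_ch = torch.cat([1 - p, p], dim=1)   # [background, foreground]
assert torch.allclose(two_ch.sum(dim=1), torch.ones(2, 4, 4))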
91-99: LGTM - probability conversion logic is correct. The is_already_prob flag prevents double conversion for single-channel inputs, and clamping protects against numerical instability.
288-291: Shape validation logic is sound. The check for the binary logits case (y_pred.shape[1] == 1 and not self.use_softmax) correctly allows a shape mismatch when appropriate.
Signed-off-by: ytl0623 <[email protected]>
05dac9e to cbed38d (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/unified_focal_loss.py (1)
240-240: Remove unused num_classes parameter. The num_classes parameter is stored but never used in the implementation. It's not passed to internal loss instances and doesn't affect behavior. Either use it to validate inputs or remove it from the interface.
♻️ Duplicate comments (3)
monai/losses/unified_focal_loss.py (3)
83-83: Add stacklevel=2 to warning for proper caller attribution. The warning call should specify stacklevel=2 to point to the caller's code rather than this line. 🔎 Proposed fix
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
188-188: Add stacklevel=2 to warning for proper caller attribution. The warning call should specify stacklevel=2 to point to the caller's code. 🔎 Proposed fix
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
293-296: Handle or reject reduction=NONE to prevent shape mismatch. When reduction=NONE, AsymmetricFocalLoss returns shape (B, H, W, [D]) (per-pixel) while AsymmetricFocalTverskyLoss returns shape (B, C) (per-class). Line 296's addition will fail. Either document and reject NONE reduction with a runtime check, or ensure both losses return compatible shapes. 🔎 Proposed fix - reject NONE reduction
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
             y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
                 Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
             y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
         """
+        if self.reduction == LossReduction.NONE.value:
+            raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from component losses.")
+
         if y_pred.shape != y_true.shape:
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)
51-53: Clarify sigmoid behavior for multi-channel inputs. The docstring states sigmoid is used "for binary/multi-label" when use_softmax=False, but doesn't clarify that sigmoid is applied independently to each channel in multi-channel cases. This differs from binary-only behavior where the background channel is constructed. Consider adding: "For multi-channel inputs, sigmoid is applied per-channel independently (multi-label)."
159-159: Enhance use_softmax documentation with usage guidance. Similar to AsymmetricFocalTverskyLoss, the docstring should clarify when to use softmax (mutually exclusive classes) vs. sigmoid (multi-label/overlapping classes). Consider adding: "Use True for mutually exclusive multi-class segmentation, False for binary or multi-label scenarios."
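Combining the two documentation suggestions, a possible wording is sketched below (suggested text only, not the current docstring):

class _UseSoftmaxDocSketch:
    """
    Args:
        use_softmax: whether to convert raw logits to probabilities with softmax.
            Use True for mutually exclusive multi-class segmentation; with False,
            sigmoid is applied to each channel independently, which suits binary
            and multi-label (overlapping classes) scenarios.
    """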
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 2
♻️ Duplicate comments (1)
monai/losses/unified_focal_loss.py (1)
175-200: Same redundant probability conversion issue as AsymmetricFocalTverskyLoss. Lines 175-182 duplicate the auto-handle logic from AsymmetricFocalTverskyLoss with the same flaw: single-channel input always uses sigmoid regardless of the use_softmax setting. Apply the same fix as suggested for AsymmetricFocalTverskyLoss to respect use_softmax after expanding to 2-channel.
🧹 Nitpick comments (3)
monai/losses/unified_focal_loss.py (3)
213-226: Reduction logic duplicates the default fallback. Lines 220-226 explicitly handle MEAN/SUM/NONE, then line 226 returns torch.mean(total_loss) as a fallback. This fallback is unreachable if the LossReduction enum is exhaustive. 🔎 Simplification
Remove redundant fallback or add a warning if an unknown reduction is encountered:
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(total_loss)
         if self.reduction == LossReduction.SUM.value:
             return torch.sum(total_loss)
         if self.reduction == LossReduction.NONE.value:
             return total_loss
-        return torch.mean(total_loss)
+        raise ValueError(f"Unsupported reduction: {self.reduction}")
Same applies to AsymmetricFocalTverskyLoss lines 122-129.
288-291: Shape validation allows mismatch only for binary logits, but one-hot conversion happens downstream. Lines 288-291 permit a shape mismatch if is_binary_logits (C=1 with sigmoid) or if to_onehot_y=True. However, the internal losses perform one-hot conversion independently. If y_true has a mismatched shape and to_onehot_y=False, the internal losses will raise ValueError at their shape checks (lines 89, 194). This validation is redundant; the internal losses already enforce shape compatibility.
🔎 Simplification
Remove this check and let internal losses handle validation:
-        if y_pred.shape != y_true.shape:
-            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
-            if not self.to_onehot_y and not is_binary_logits:
-                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-
Or add a comment explaining why this pre-validation is needed.
89-89: Long exception messages flagged by static analysis. Lines 89, 194, and 291 embed long f-string messages directly in ValueError. Ruff (TRY003) suggests defining exception classes or message constants for long messages.
For consistency with MONAI conventions, verify if other loss modules use inline messages or constants. If this pattern is acceptable project-wide, ignore the hint.
Also applies to: 194-194, 291-291
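If constants were preferred, the pattern could be as small as the following (hypothetical name, not existing MONAI code):

# Module-level message template, illustrative only.
SHAPE_MISMATCH_MSG = "ground truth has different shape ({gt}) from input ({pred})"

# Inside forward(), the raise would then read:
# raise ValueError(SHAPE_MISMATCH_MSG.format(gt=y_true.shape, pred=y_pred.shape))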
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py
89-89: Avoid specifying long messages outside the exception class
(TRY003)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
291-291: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (2)
monai/losses/unified_focal_loss.py (2)
266-279: Gamma parameters have opposite semantics in AsymmetricFocalLoss vs AsymmetricFocalTverskyLoss. AsymmetricFocalLoss uses gamma directly: torch.pow(1 - y_pred, gamma), while AsymmetricFocalTverskyLoss uses its reciprocal: torch.pow(1 - dice_class, 1/gamma). Per the paper, Focal Tversky's optimal gamma=4/3 enhances the loss (contrary to Focal loss, which suppresses it). Passing the same gamma=0.5 to both produces mismatched behaviors and may not match the paper's unified formulation intent.
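A standalone numeric illustration of the mismatch (the value 0.3 is arbitrary):

import torch

gamma = 0.5
err = torch.tensor(0.3)  # stand-in for 1 - p (focal term) or 1 - dice (Tversky term)
focal_term = torch.pow(err, gamma)  # err ** 0.5  ~= 0.55: modulation stays large
tversky_term = torch.pow(err, 1 / gamma)  # err ** 2.0  ~= 0.09: modulation shrinks quickly
print(focal_term.item(), tversky_term.item())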
114-115: The focal modulation formula is correct: it properly focuses on hard examples, not the reverse. With gamma = 0.75 (default), the exponent 1/gamma = 1.333 > 1, which makes hard examples (low Dice values) contribute more to the loss than easy examples, not less. Raising values in (0, 1) to a power greater than 1 shrinks small values (easy examples, where 1 - dice is small) proportionally more than large values (hard examples, where 1 - dice is large), so easy examples are down-weighted relative to hard ones. This is standard focal behavior and matches the docstring: "focal exponent value to down-weight easy foreground examples." The Unified Focal Loss specifies γ < 1 increases focusing on harder examples, and MONAI's reparameterization using 1/gamma as the exponent achieves this: gamma = 0.75 yields exponent 1.333, which focuses on hard examples correctly. Likely an incorrect or invalid review comment.
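That behaviour can be checked numerically (standalone sketch):

import torch

one_minus_dice = torch.tensor([0.1, 0.9])  # easy vs hard foreground class
focused = torch.pow(one_minus_dice, 1 / 0.75)  # exponent ~= 1.333
print(focused)  # ~[0.046, 0.869]: the easy term shrinks far more than the hard one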
+        # Auto-handle single channel input (binary segmentation case)
+        if y_pred.shape[1] == 1 and not self.use_softmax:
+            y_pred = torch.sigmoid(y_pred)
+            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+            is_already_prob = True
+            if y_true.shape[1] == 1:
+                y_true = one_hot(y_true, num_classes=2)
+        else:
+            is_already_prob = False

         n_pred_ch = y_pred.shape[1]

         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+                if y_true.shape[1] != n_pred_ch:
+                    y_true = one_hot(y_true, num_classes=n_pred_ch)

         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

-        # clip the prediction to avoid NaN
+        # Convert logits to probabilities if not already done
+        if not is_already_prob:
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred, dim=1)
+            else:
+                y_pred = torch.sigmoid(y_pred)
Redundant probability conversion when auto-handling single-channel input.
Lines 69-77 apply sigmoid and set is_already_prob=True, then lines 91-96 skip conversion when is_already_prob. However, if use_softmax=True with a single-channel input, the code applies sigmoid at line 71, but softmax would be more appropriate for the 2-channel representation created at line 72.
The logic forces sigmoid for single-channel regardless of use_softmax, which may not align with user intent.
🔎 Suggested approach
Apply the activation chosen by use_softmax after expanding to 2-channel:
-        if y_pred.shape[1] == 1 and not self.use_softmax:
+        if y_pred.shape[1] == 1:
             y_pred = torch.sigmoid(y_pred)
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            is_already_prob = True
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred, dim=1)
+            is_already_prob = True
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
         else:
             is_already_prob = False
Or document that single-channel always uses sigmoid regardless of use_softmax.
🧰 Tools
🪛 Ruff (0.14.10)
89-89: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 69-96, the code currently
always applies sigmoid when auto-handling a single-channel prediction, which is
wrong when use_softmax=True; instead, for single-channel inputs defer
activation: if y_pred has 1 channel and use_softmax is True, expand the logits
to two channels (create a compatible two-channel logits tensor, e.g., by
concatenating the negated and original logits or an appropriate pair) and leave
is_already_prob=False so the later softmax branch runs; if use_softmax is False,
perform the sigmoid, concatenate probabilities to two channels and set
is_already_prob=True; ensure subsequent one-hot conversion and shape checks
remain unchanged.
         # Calculate losses separately for each class
         # Background: Standard Dice Loss
         back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
-
-        # Average class scores
-        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
-        return loss
+        # Foreground: Focal Tversky Loss
+        fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
+
+        # Concatenate background and foreground losses
+        # back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1)
+        all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1)
+
+        # Apply reduction
+        if self.reduction == LossReduction.MEAN.value:
+            return torch.mean(all_losses)
+        if self.reduction == LossReduction.SUM.value:
+            return torch.sum(all_losses)
+        if self.reduction == LossReduction.NONE.value:
+            return all_losses
+
+        return torch.mean(all_losses)
Return shape with reduction=NONE is (B, C), not per-pixel.
With reduction=NONE, line 127 returns all_losses with shape (B, C) (per-class). This differs from AsymmetricFocalLoss, which returns per-pixel shape. AsymmetricUnifiedFocalLoss combines both at line 296, causing a shape mismatch when reduction=NONE.
Past review comment flagged this but it remains unresolved.
🔎 Recommended fix
Add runtime guard in AsymmetricUnifiedFocalLoss:
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        if self.reduction == LossReduction.NONE.value:
+            raise ValueError("AsymmetricUnifiedFocalLoss does not support reduction='none' due to incompatible output shapes from constituent losses.")
Update docstring to document this limitation and add test coverage.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 110 to 129, with
reduction=NONE the function returns all_losses shaped (B, C) (per-class) which
mismatches the per-pixel shape expected by AsymmetricUnifiedFocalLoss; add a
runtime guard to detect reduction==LossReduction.NONE and either (preferable per
the review) raise a clear ValueError explaining that AsymmetricUnifiedFocalLoss
cannot be used with reduction=NONE because this implementation returns per-class
losses, or alternatively change the reduction=NONE branch to compute and return
per-pixel losses to match AsymmetricFocalLoss; update the function docstring to
document this limitation (or the new behavior) and add unit tests that assert
the ValueError is raised (or validate the corrected per-pixel output shape) to
prevent regressions.
Signed-off-by: ytl0623 <[email protected]>
Actionable comments posted: 1
♻️ Duplicate comments (2)
monai/losses/unified_focal_loss.py (2)
176-183: Duplicate issue: single-channel incompatible with use_softmax=True. Same issue as AsymmetricFocalTverskyLoss. Apply the same fix to handle single-channel inputs with use_softmax=True.
69-78: Single-channel input incompatible with use_softmax=True. Line 70's condition and not self.use_softmax skips auto-expansion when use_softmax=True. A 1-channel tensor passed to torch.softmax at line 95 produces valid probabilities but remains 1-channel, causing shape mismatches downstream. Either document that single-channel input requires use_softmax=False, or expand single-channel logits to 2-class before softmax. 🔎 Proposed fix
-        if y_pred.shape[1] == 1 and not self.use_softmax:
+        if y_pred.shape[1] == 1:
+            if self.use_softmax:
+                # Expand to 2-class logits for softmax: [logit] -> [-logit, logit]
+                y_pred = torch.cat([-y_pred, y_pred], dim=1)
+                is_already_prob = False
+            else:
-            y_pred = torch.sigmoid(y_pred)
-            y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            is_already_prob = True
+                y_pred = torch.sigmoid(y_pred)
+                y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
+                is_already_prob = True
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
🧹 Nitpick comments (2)
monai/losses/unified_focal_loss.py (2)
90-90: Consider extracting exception message to a constant. Ruff flags the inline message. Extract to a module-level constant if this pattern recurs.
Based on static analysis hint.
195-195: Consider extracting exception message to a constant. Same static analysis hint as line 90.
Based on static analysis hint.
🧬 Code graph analysis (1)
monai/losses/unified_focal_loss.py (2)
monai/utils/enums.py (1)
LossReduction (253-264)
monai/networks/utils.py (1)
one_hot (170-220)
🪛 Ruff (0.14.10)
monai/losses/unified_focal_loss.py
90-90: Avoid specifying long messages outside the exception class
(TRY003)
195-195: Avoid specifying long messages outside the exception class
(TRY003)
290-290: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (6)
monai/losses/unified_focal_loss.py (6)
92-100: Probability conversion logic is correct. The conditional activation and epsilon clamping properly handle both softmax and sigmoid cases while preventing redundant conversions.
111-130: Per-class loss computation handles edge cases correctly. The positive exponent 1/self.gamma at line 116 avoids infinity when dice approaches 1.0, addressing the numerical stability concern from past reviews. The (B, C) output shape with reduction=NONE is intentional for the wrapper's spatial averaging.
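The stability point is easy to see in isolation (a sketch of the bare exponent term, not the full loss):

import torch

gamma = 0.75
dice = torch.tensor(0.999999)  # nearly perfect prediction, so 1 - dice is ~1e-6
old_term = torch.pow(1 - dice, -gamma)  # negative exponent: ~3e4, diverges as dice -> 1
new_term = torch.pow(1 - dice, 1 / gamma)  # positive exponent: ~1e-8, stays bounded
print(old_term.item(), new_term.item())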
197-203: Probability conversion and clamping are correct. Consistent with the AsymmetricFocalTverskyLoss implementation.
207-225: Asymmetric focal loss computation is mathematically sound. Background focal modulation and foreground cross-entropy are correctly weighted and concatenated. The per-pixel (B, C, H, W, [D]) output with reduction=NONE properly complements AsymmetricFocalTverskyLoss.
292-300: Shape alignment correctly resolves the reduction=NONE mismatch. Lines 295-297 average the per-pixel focal loss spatially to produce (B, C), matching the Tversky loss shape. This resolves the shape mismatch concern from past reviews. The weighted combination at line 300 now operates on compatible tensors.
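The spatial averaging being described amounts to something like this (a sketch of the idea with made-up shapes, not the exact PR code):

import torch

focal = torch.rand(2, 3, 8, 8)  # per-pixel focal loss: (B, C, H, W)
tversky = torch.rand(2, 3)      # per-class Tversky loss: (B, C)
focal_bc = focal.mean(dim=tuple(range(2, focal.dim())))  # average spatial dims -> (B, C)
combined = 0.5 * focal_bc + 0.5 * tversky                # shapes now match for the weighted sum
print(combined.shape)  # torch.Size([2, 3])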
303-311: Final reduction logic is clean and correct. All three reduction modes are handled explicitly with a fallback to MEAN. The implementation properly applies reduction to the combined (B, C) loss tensor.
         if y_pred.shape != y_true.shape:
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-
-        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
-            raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
-
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
+            if not self.to_onehot_y and not is_binary_logits:
+                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
🛠️ Refactor suggestion | 🟠 Major
Shape validation logic needs clarification.
Line 288 defines is_binary_logits but only uses it at line 289. The condition allows shape mismatch for binary logits or when to_onehot_y=True, but the logic is unclear. Document when shape mismatch is expected vs an error.
🔎 Add clarifying comment
         if y_pred.shape != y_true.shape:
+            # Allow mismatch when: (1) binary logits will be auto-expanded, or (2) y_true will be one-hot encoded
             is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
             if not self.to_onehot_y and not is_binary_logits:
                 raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
Based on static analysis hint.
🧰 Tools
🪛 Ruff (0.14.10)
290-290: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 287-290, the current
shape-check treats a mismatch as OK for either to_onehot_y=True or for "binary
logits" (is_binary_logits) but has no explanation or clear rule for allowed
shapes; add a concise clarifying comment above this block that states the exact
allowed cases (e.g., when to_onehot_y=True we expect class-indexed y_true shapes
different from y_pred, and for binary logits y_pred has a channel dim of 1 while
y_true may omit that channel so a shape mismatch is acceptable), and tighten the
condition by explicitly checking the common binary case (y_true ndim equals
y_pred ndim - 1) before allowing the mismatch so behaviour is unambiguous.
Fixes #8603
Description
Refactors AsymmetricUnifiedFocalLoss and its sub-components (AsymmetricFocalLoss, AsymmetricFocalTverskyLoss) to extend support from binary-only to multi-class segmentation, while also fixing mathematical logic errors and parameter-passing bugs. A usage sketch follows the checklist below.
Types of changes
- Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
- Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
- Documentation updated, tested with the make html command in the docs/ folder.
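For illustration, the intended call pattern after this change might look like the following; this assumes the use_softmax argument lands as described in the reviews and is not taken verbatim from the PR:

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Multi-class case: 4-class logits with integer labels, softmax interface (assumed new argument).
loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True)
logits = torch.randn(2, 4, 32, 32)  # (B, C, H, W) raw scores
labels = torch.randint(0, 4, (2, 1, 32, 32))  # (B, 1, H, W) class indices
loss = loss_fn(logits, labels)
print(loss.item())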