Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 221 additions & 53 deletions monai/losses/cldice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@

from __future__ import annotations

import warnings
from collections.abc import Callable

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.losses.dice import DiceLoss
from monai.networks import one_hot
from monai.utils import LossReduction


def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore
"""
Expand Down Expand Up @@ -92,26 +99,6 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
return skel


def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
"""
Function to compute soft dice loss

Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22

Args:
y_true: the shape should be BCH(WD)
y_pred: the shape should be BCH(WD)

Returns:
dice loss
"""
intersection = torch.sum((y_true * y_pred)[:, 1:, ...])
coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth)
soft_dice: torch.Tensor = 1.0 - coeff
return soft_dice


class SoftclDiceLoss(_Loss):
"""
Compute the Soft clDice loss defined in:
Expand All @@ -121,64 +108,245 @@ class SoftclDiceLoss(_Loss):

Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7

The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).
Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input,
must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target`
can be 1 or N (one-hot format).

"""

def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
def __init__(
self,
iter_: int = 3,
smooth: float = 1.0,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
) -> None:
"""
Args:
iter_: Number of iterations for skeletonization
smooth: Smoothing parameter
iter_: Number of iterations for skeletonization.
smooth: Smoothing parameter to avoid division by zero. Defaults to 1.0.
include_background: if False, channel index 0 (background category) is excluded from the calculation.
if the non-background segmentations are small compared to the total image size they can get overwhelmed
by the signal from the background so excluding it in such cases helps convergence.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction.
softmax: if True, apply a softmax function to the prediction.
other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
``other_act = torch.tanh``.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.

"""
super().__init__()
super().__init__(reduction=LossReduction(reduction).value)
if other_act is not None and not callable(other_act):
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
if smooth <= 0:
raise ValueError(f"smooth must be a positive value but got {smooth}.")
self.iter = iter_
Comment on lines +161 to 163
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Missing validation for iter_ parameter.

smooth is validated but iter_ is not. A non-positive value would produce incorrect skeletonization.

Proposed fix
         if smooth <= 0:
             raise ValueError(f"smooth must be a positive value but got {smooth}.")
+        if iter_ < 0:
+            raise ValueError(f"iter_ must be a non-negative integer but got {iter_}.")
         self.iter = iter_
🧰 Tools
🪛 Ruff (0.14.11)

162-162: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@monai/losses/cldice.py` around lines 161 - 163, The code currently validates
smooth but not the skeletonization iteration count iter_; add validation in the
clDice constructor (or the function where self.iter is set) to ensure iter_ is a
positive integer: check that iter_ is an int (or castable to int) and greater
than 0, and raise a ValueError with a clear message if not; then assign
self.iter = int(iter_) so downstream skeletonize/skel operations use a safe
positive integer (refer to the self.iter assignment and the iter_ parameter in
the clDice class/constructor).

self.smooth = smooth
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.sigmoid = sigmoid
self.softmax = softmax
self.other_act = other_act

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

Raises:
AssertionError: When input and target (after one hot transform if set)
have different shapes.

"""
n_pred_ch = input.shape[1]

if self.sigmoid:
input = torch.sigmoid(input)

if self.softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.")
else:
input = torch.softmax(input, dim=1)

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
if self.other_act is not None:
input = self.other_act(input)

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
else:
target = target[:, 1:]
input = input[:, 1:]

if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

skel_pred = soft_skel(input, self.iter)
skel_true = soft_skel(target, self.iter)

# Compute per-batch clDice by reducing over channel and spatial dimensions
# reduce_axis includes all dimensions except batch (dim 0)
reduce_axis: list[int] = list(range(1, len(input.shape)))

tprec = (torch.sum(torch.multiply(skel_pred, target), dim=reduce_axis) + self.smooth) / (
torch.sum(skel_pred, dim=reduce_axis) + self.smooth
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
tsens = (torch.sum(torch.multiply(skel_true, input), dim=reduce_axis) + self.smooth) / (
torch.sum(skel_true, dim=reduce_axis) + self.smooth
)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-8)

# Apply reduction
if self.reduction == LossReduction.MEAN.value:
cl_dice = torch.mean(cl_dice)
elif self.reduction == LossReduction.SUM.value:
cl_dice = torch.sum(cl_dice)
elif self.reduction == LossReduction.NONE.value:
pass # keep per-batch values
else:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return cl_dice


class SoftDiceclDiceLoss(_Loss):
"""
Compute the Soft clDice loss defined in:
Compute both Dice loss and clDice loss, and return the weighted sum of these two losses.
The details of Dice loss is shown in ``monai.losses.DiceLoss``.
The details of clDice loss is shown in ``monai.losses.SoftclDiceLoss``.

Adapted from:
Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)

Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
"""

def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
def __init__(
self,
iter_: int = 3,
alpha: float = 0.5,
smooth: float = 1.0,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
) -> None:
"""
Args:
iter_: Number of iterations for skeletonization
smooth: Smoothing parameter
alpha: Weighing factor for cldice
iter_: Number of iterations for skeletonization, used by clDice.
alpha: Weighing factor for cldice component. Total loss = (1 - alpha) * dice + alpha * cldice.
Defaults to 0.5.
smooth: Smoothing parameter to avoid division by zero, used by both Dice and clDice. Defaults to 1.0.
include_background: if False, channel index 0 (background category) is excluded from the calculation.
if the non-background segmentations are small compared to the total image size they can get overwhelmed
by the signal from the background so excluding it in such cases helps convergence.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction.
softmax: if True, apply a softmax function to the prediction.
other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
``other_act = torch.tanh``.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.

"""
super().__init__()
self.iter = iter_
self.smooth = smooth
self.alpha = alpha

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
dice = soft_dice(y_true, y_pred, self.smooth)
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
if smooth <= 0:
raise ValueError(f"smooth must be a positive value but got {smooth}.")
self.dice = DiceLoss(
include_background=include_background,
to_onehot_y=False,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
reduction=reduction,
smooth_nr=smooth,
smooth_dr=smooth,
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
self.cldice = SoftclDiceLoss(
iter_=iter_,
smooth=smooth,
include_background=include_background,
to_onehot_y=False,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
reduction=reduction,
)
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
self.alpha = alpha
self.to_onehot_y = to_onehot_y

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.

"""
if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} and {target.shape}."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

if self.to_onehot_y:
n_pred_ch = input.shape[1]
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
target = one_hot(target, num_classes=n_pred_ch)

dice_loss = self.dice(input, target)
cldice_loss = self.cldice(input, target)
total_loss: torch.Tensor = (1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss

return total_loss
Loading
Loading