Skip to content

Commit 6532e35

Browse files
committed
add dice_loss
1 parent d710f3d commit 6532e35

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

torchvision/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .roi_align import roi_align, RoIAlign
2727
from .roi_pool import roi_pool, RoIPool
2828
from .stochastic_depth import stochastic_depth, StochasticDepth
29+
from .dice_loss import dice_loss
2930

3031
_register_custom_op()
3132

torchvision/ops/dice_loss.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-8) -> torch.Tensor:
5+
"""Criterion that computes Sørensen-Dice Coefficient loss.
6+
7+
We compute the Sørensen-Dice Coefficient as follows:
8+
9+
.. math::
10+
11+
\text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}
12+
13+
Where:
14+
- :math:`X` expects to be the scores of each class.
15+
- :math:`Y` expects to be thess tensor with the class labels.
16+
17+
the loss, is finally computed as:
18+
19+
.. math::
20+
21+
\text{loss}(x, class) = 1 - \text{Dice}(x, class)
22+
23+
Args:
24+
inputs: (Tensor): A float tensor of arbitrary shape.
25+
The predictions for each example.
26+
targets: (Tensor): A float tensor with the same shape as inputs. Stores the binary
27+
classification label for each element in inputs
28+
(0 for the negative class and 1 for the positive class).
29+
eps: (float, optional): Scalar to enforce numerical stabiliy.
30+
reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``
31+
``'none'``: No reduction will be applied to the output.
32+
``'mean'``: The output will be averaged.
33+
``'sum'``: The output will be summed. Default: ``'none'``.
34+
35+
Return:
36+
Tensor: Loss tensor with the reduction option applied.
37+
"""
38+
if not isinstance(inputs, torch.Tensor):
39+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(inputs)}")
40+
41+
if not len(inputs.shape) == 4:
42+
raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {inputs.shape}")
43+
44+
if not inputs.shape[-2:] == targets.shape[-2:]:
45+
raise ValueError(f"input and target shapes must be the same. Got: {inputs.shape} and {targets.shape}")
46+
47+
if not inputs.device == targets.device:
48+
raise ValueError(f"input and target must be in the same device. Got: {inputs.device} and {targets.device}")
49+
50+
# compute softmax over the classes axis
51+
p = F.softmax(inputs, dim=1)
52+
53+
# compute the actual dice score
54+
dims = (1, 2, 3)
55+
intersection = torch.sum(p * targets, dims)
56+
cardinality = torch.sum(p + targets, dims)
57+
58+
dice_score = 2.0 * intersection / (cardinality + eps)
59+
60+
loss = 1.0 - dice_score
61+
62+
# Check reduction option and return loss accordingly
63+
if reduction == "none":
64+
pass
65+
elif reduction == "mean":
66+
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
67+
elif reduction == "sum":
68+
loss = loss.sum()
69+
else:
70+
raise ValueError(
71+
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
72+
)
73+
74+
return loss

0 commit comments

Comments
 (0)