1
+ import torch
2
+ import torch .nn .functional as F
3
+
4
+ def dice_loss (inputs : torch .Tensor , targets : torch .Tensor , reduction : str = "none" , eps : float = 1e-8 ) -> torch .Tensor :
5
+ """Criterion that computes Sørensen-Dice Coefficient loss.
6
+
7
+ We compute the Sørensen-Dice Coefficient as follows:
8
+
9
+ .. math::
10
+
11
+ \t ext{Dice}(x, class) = \f rac{2 |X \cap Y|}{|X| + |Y|}
12
+
13
+ Where:
14
+ - :math:`X` expects to be the scores of each class.
15
+ - :math:`Y` expects to be thess tensor with the class labels.
16
+
17
+ the loss, is finally computed as:
18
+
19
+ .. math::
20
+
21
+ \t ext{loss}(x, class) = 1 - \t ext{Dice}(x, class)
22
+
23
+ Args:
24
+ inputs: (Tensor): A float tensor of arbitrary shape.
25
+ The predictions for each example.
26
+ targets: (Tensor): A float tensor with the same shape as inputs. Stores the binary
27
+ classification label for each element in inputs
28
+ (0 for the negative class and 1 for the positive class).
29
+ eps: (float, optional): Scalar to enforce numerical stabiliy.
30
+ reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``
31
+ ``'none'``: No reduction will be applied to the output.
32
+ ``'mean'``: The output will be averaged.
33
+ ``'sum'``: The output will be summed. Default: ``'none'``.
34
+
35
+ Return:
36
+ Tensor: Loss tensor with the reduction option applied.
37
+ """
38
+ if not isinstance (inputs , torch .Tensor ):
39
+ raise TypeError (f"Input type is not a torch.Tensor. Got { type (inputs )} " )
40
+
41
+ if not len (inputs .shape ) == 4 :
42
+ raise ValueError (f"Invalid input shape, we expect BxNxHxW. Got: { inputs .shape } " )
43
+
44
+ if not inputs .shape [- 2 :] == targets .shape [- 2 :]:
45
+ raise ValueError (f"input and target shapes must be the same. Got: { inputs .shape } and { targets .shape } " )
46
+
47
+ if not inputs .device == targets .device :
48
+ raise ValueError (f"input and target must be in the same device. Got: { inputs .device } and { targets .device } " )
49
+
50
+ # compute softmax over the classes axis
51
+ p = F .softmax (inputs , dim = 1 )
52
+
53
+ # compute the actual dice score
54
+ dims = (1 , 2 , 3 )
55
+ intersection = torch .sum (p * targets , dims )
56
+ cardinality = torch .sum (p + targets , dims )
57
+
58
+ dice_score = 2.0 * intersection / (cardinality + eps )
59
+
60
+ loss = 1.0 - dice_score
61
+
62
+ # Check reduction option and return loss accordingly
63
+ if reduction == "none" :
64
+ pass
65
+ elif reduction == "mean" :
66
+ loss = loss .mean () if loss .numel () > 0 else 0.0 * loss .sum ()
67
+ elif reduction == "sum" :
68
+ loss = loss .sum ()
69
+ else :
70
+ raise ValueError (
71
+ f"Invalid Value for arg 'reduction': '{ reduction } \n Supported reduction modes: 'none', 'mean', 'sum'"
72
+ )
73
+
74
+ return loss
0 commit comments