@@ -28,10 +28,10 @@ class FocalLoss(_Loss):
     FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
     high confidence correct predictions.

-    Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in:
+    Reimplementation of the Focal Loss described in:

-        - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
-        - "AnatomyNet: Deep learning for fast and fully automated whole‐ volume segmentation of head and neck anatomy",
+        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
+        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
          Zhu et al., Medical Physics 2018

     Example:
@@ -70,19 +70,25 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
+        alpha: float | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
     ) -> None:
         """
         Args:
-            include_background: if False, channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            gamma: value of the exponent gamma in the definition of the Focal loss.
+            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
+                If False, `alpha` is ignored when using softmax.
+            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            alpha: value of alpha in the definition of the alpha-balanced Focal loss.
+                The value should be in [0, 1]. Defaults to None.
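+                ``alpha`` scales the foreground (positive) terms and ``1 - alpha`` the
+                background (negative) terms, balancing their contributions to the loss.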
             weight: weights to apply to the voxels of each class. If None no weights are applied.
-                This corresponds to the weights `\alpha` in [1].
                 The input can be a single value (same weight for all classes), a sequence of values (the length
-                of the sequence should be the same as the number of classes, if not ``include_background``, the
-                number should not include class 0).
+                of the sequence should be the same as the number of classes. If not ``include_background``,
+                the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
                 Specifies the reduction to apply to the output. Defaults to ``"mean"``.
@@ -91,6 +97,11 @@ def __init__(
                 - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                 - ``"sum"``: the output will be summed.

+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
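+                Softmax suits mutually exclusive multi-class targets, while sigmoid
+                treats each channel as an independent binary (multi-label) target.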
+
         Example:
             >>> import torch
             >>> from monai.losses import FocalLoss
@@ -103,14 +114,16 @@ def __init__(
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.weight: Sequence[float] | float | int | torch.Tensor | None = weight
+        self.alpha = alpha
+        self.weight = weight
+        self.use_softmax = use_softmax

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: the shape should be BNH[WD], where N is the number of classes.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
+                a sigmoid/softmax in the forward function.
             target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

         Raises:
@@ -141,63 +154,125 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if target.shape != input.shape:
             raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

-        i = input
-        t = target
-
-        # Change the shape of input and target to B x N x num_voxels.
-        b, n = t.shape[:2]
-        i = i.reshape(b, n, -1)
-        t = t.reshape(b, n, -1)
-
-        # computing binary cross entropy with logits
-        # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
-        max_val = (-i).clamp(min=0)
-        ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
+        loss: Optional[torch.Tensor] = None
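+        # cast up front so integer targets and half-precision logits work with the exp/log math below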
+        input = input.float()
+        target = target.float()
+        if self.use_softmax:
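+            # the alpha weighting in the softmax path assumes channel 0 is the
+            # background class, so it is dropped when background is excluded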
+            if not self.include_background and self.alpha is not None:
+                self.alpha = None
+                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+        else:
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)

         if self.weight is not None:
+            # make sure the lengths of weights are equal to the number of classes
             class_weight: Optional[torch.Tensor] = None
+            num_of_classes = target.shape[1]
             if isinstance(self.weight, (float, int)):
-                class_weight = torch.as_tensor([self.weight] * i.size(1))
+                class_weight = torch.as_tensor([self.weight] * num_of_classes)
             else:
                 class_weight = torch.as_tensor(self.weight)
-                if class_weight.size(0) != i.size(1):
+                if class_weight.shape[0] != num_of_classes:
                     raise ValueError(
-                        "the length of the weight sequence should be the same as the number of classes. "
-                        + "If `include_background=False`, the number should not include class 0."
+                        """the length of the `weight` sequence should be the same as the number of classes.
+                        If `include_background=False`, the weight should not include
+                        the background category class 0."""
                     )
             if class_weight.min() < 0:
-                raise ValueError("the value/values of weights should be no less than 0.")
-            class_weight = class_weight.to(i)
-            # Convert the weight to a map in which each voxel
-            # has the weight associated with the ground-truth label
-            # associated with this voxel in target.
-            at = class_weight[None, :, None]  # N => 1,N,1
-            at = at.expand((t.size(0), -1, t.size(2)))  # 1,N,1 => B,N,H*W
-            # Multiply the log proba by their weights.
-            ce = ce * at
-
-        # Compute the loss mini-batch.
-        # (1-p_t)^gamma * log(p_t) with reduced chance of overflow
-        p = F.logsigmoid(-i * (t * 2.0 - 1.0))
-        flat_loss: torch.Tensor = (p * self.gamma).exp() * ce
-
-        # Previously there was a mean over the last dimension, which did not
-        # return a compatible BCE loss. To maintain backwards compatible
-        # behavior we have a flag that performs this extra step, disable or
-        # parameterize if necessary. (Or justify why the mean should be there)
-        average_spatial_dims = True
+                raise ValueError("the value/values of the `weight` should be no less than 0.")
+            # apply class_weight to loss
+            class_weight = class_weight.to(loss)
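+            # reshape the per-class weights to (N, 1, ...) so they broadcast over
+            # the batch and spatial dimensions of the BNH[WD] loss map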
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            class_weight = class_weight.view(broadcast_dims)
+            loss = class_weight * loss

         if self.reduction == LossReduction.SUM.value:
+            # Previously there was a mean over the last dimension, which did not
+            # return a compatible BCE loss. To maintain backwards compatible
+            # behavior we have a flag that performs this extra step, disable or
+            # parameterize if necessary. (Or justify why the mean should be there)
+            average_spatial_dims = True
             if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.sum()
+                loss = loss.mean(dim=list(range(2, len(target.shape))))
+            loss = loss.sum()
         elif self.reduction == LossReduction.MEAN.value:
-            if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.mean()
+            loss = loss.mean()
         elif self.reduction == LossReduction.NONE.value:
-            spacetime_dims = input.shape[2:]
-            loss = flat_loss.reshape([b, n] + list(spacetime_dims))
+            pass
         else:
             raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
         return loss
+
+
+def softmax_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
+    s_j is the unnormalized score for class j.
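+
+    Example (an illustrative shape check on hypothetical inputs):
+        >>> import torch
+        >>> logits = torch.randn(2, 3, 4, 4)  # B=2, N=3 classes
+        >>> labels = torch.nn.functional.one_hot(torch.randint(0, 3, (2, 4, 4)), 3).permute(0, 3, 1, 2).float()
+        >>> softmax_focal_loss(logits, labels).shape
+        torch.Size([2, 3, 4, 4])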
+ """
213
+ input_ls = input .log_softmax (1 )
214
+ loss : torch .Tensor = - (1 - input_ls .exp ()).pow (gamma ) * input_ls * target
215
+
216
+ if alpha is not None :
217
+ # (1-alpha) for the background class and alpha for the other classes
218
+ alpha_fac = torch .tensor ([1 - alpha ] + [alpha ] * (target .shape [1 ] - 1 )).to (loss )
219
+ broadcast_dims = [- 1 ] + [1 ] * len (target .shape [2 :])
220
+ alpha_fac = alpha_fac .view (broadcast_dims )
221
+ loss = alpha_fac * loss
222
+
223
+ return loss
224
+
225
+
226
+ def sigmoid_focal_loss (
227
+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : Optional [float ] = None
228
+ ) -> torch .Tensor :
229
+ """
230
+ FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
231
+
232
+ where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
233
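+
+    Example (an illustrative shape check on hypothetical inputs):
+        >>> import torch
+        >>> logits = torch.randn(2, 3, 4, 4)  # B=2, N=3 independent binary channels
+        >>> labels = torch.randint(0, 2, (2, 3, 4, 4)).float()
+        >>> sigmoid_focal_loss(logits, labels, gamma=2.0, alpha=0.25).shape
+        torch.Size([2, 3, 4, 4])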
+ """
234
+ # computing binary cross entropy with logits
235
+ # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
236
+ # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
237
+ max_val = (- input ).clamp (min = 0 )
238
+ loss : torch .Tensor = input - input * target + max_val + ((- max_val ).exp () + (- input - max_val ).exp ()).log ()
239
+
240
+ # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
241
+ # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
242
+ # 1-p if t==1; p if t==0 <=>
243
+ # pfac, that is, the term (1 - pt)
244
+ invprobs = F .logsigmoid (- input * (target * 2 - 1 )) # reduced chance of overflow
245
+ # (pfac.log() * gamma).exp() <=>
246
+ # pfac.log().exp() ^ gamma <=>
247
+ # pfac ^ gamma
248
+ loss = (invprobs * gamma ).exp () * loss
249
+
250
+ if alpha is not None :
251
+ # alpha if t==1; (1-alpha) if t==0
252
+ alpha_factor = target * alpha + (1 - target ) * (1 - alpha )
253
+ loss = alpha_factor * loss
254
+
255
+ return loss