@@ -28,10 +28,10 @@ class FocalLoss(_Loss):
     FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
     high confidence correct predictions.
 
-    Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in:
+    Reimplementation of the Focal Loss described in:
 
-        - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
-        - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy",
+        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
+        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
           Zhu et al., Medical Physics 2018
 
     Example:
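
To make the down-weighting concrete: with the default gamma of 2, the focal factor (1 - pt)**gamma that multiplies the cross entropy shrinks rapidly as the predicted probability pt of the true class grows. A quick numeric sketch (values are arbitrary):

    import torch

    pt = torch.tensor([0.5, 0.9, 0.99])  # predicted probability of the true class
    print((1 - pt) ** 2)                 # ~[0.25, 0.01, 0.0001]: confident correct voxels contribute little
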
@@ -70,19 +70,23 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
+        alpha: float | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
     ) -> None:
         """
         Args:
-            include_background: if False, channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            gamma: value of the exponent gamma in the definition of the Focal loss.
+            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
+                If False, `alpha` is invalid when using softmax.
+            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
+                The value should be in [0, 1]. Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
-                This corresponds to the weights `\alpha` in [1].
                 The input can be a single value (same weight for all classes), a sequence of values (the length
-                of the sequence should be the same as the number of classes, if not ``include_background``, the
-                number should not include class 0).
+                of the sequence should be the same as the number of classes. If not ``include_background``,
+                the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
                 Specifies the reduction to apply to the output. Defaults to ``"mean"``.
@@ -91,6 +95,9 @@ def __init__(
                 - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                 - ``"sum"``: the output will be summed.
 
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
+
         Example:
             >>> import torch
             >>> from monai.losses import FocalLoss
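
A minimal sketch of how the new `alpha` and `use_softmax` options combine with the existing arguments (the shapes below are arbitrary, chosen only for illustration):

    import torch
    from monai.losses import FocalLoss

    pred = torch.randn(1, 3, 16, 16)            # raw logits, B x N x H x W
    grnd = torch.randint(0, 3, (1, 1, 16, 16))  # integer labels, one-hot encoded via to_onehot_y
    fl = FocalLoss(to_onehot_y=True, gamma=2.0, alpha=0.75, use_softmax=True)
    print(fl(pred, grnd))                       # scalar, since reduction defaults to "mean"
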
@@ -103,14 +110,16 @@ def __init__(
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.weight: Sequence[float] | float | int | torch.Tensor | None = weight
+        self.alpha = alpha
+        self.weight = weight
+        self.use_softmax = use_softmax
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: the shape should be BNH[WD], where N is the number of classes.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
+                a sigmoid/softmax in the forward function.
             target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
 
         Raises:
@@ -141,63 +150,106 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if target.shape != input.shape:
             raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
 
-        i = input
-        t = target
-
-        # Change the shape of input and target to B x N x num_voxels.
-        b, n = t.shape[:2]
-        i = i.reshape(b, n, -1)
-        t = t.reshape(b, n, -1)
-
-        # computing binary cross entropy with logits
-        # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
-        max_val = (-i).clamp(min=0)
-        ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
+        loss: Optional[torch.Tensor] = None
+        input = input.float()
+        target = target.float()
+        if self.use_softmax:
+            if not self.include_background and self.alpha is not None:
+                self.alpha = None
+                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+        else:
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
         if self.weight is not None:
+            # make sure the lengths of weights are equal to the number of classes
             class_weight: Optional[torch.Tensor] = None
+            num_of_classes = target.shape[1]
             if isinstance(self.weight, (float, int)):
-                class_weight = torch.as_tensor([self.weight] * i.size(1))
+                class_weight = torch.as_tensor([self.weight] * num_of_classes)
             else:
                 class_weight = torch.as_tensor(self.weight)
-                if class_weight.size(0) != i.size(1):
+                if class_weight.shape[0] != num_of_classes:
                     raise ValueError(
-                        "the length of the weight sequence should be the same as the number of classes. "
-                        + "If `include_background=False`, the number should not include class 0."
+                        """the length of the `weight` sequence should be the same as the number of classes.
+                        If `include_background=False`, the weight should not include
+                        the background category class 0."""
                     )
             if class_weight.min() < 0:
-                raise ValueError("the value/values of weights should be no less than 0.")
-            class_weight = class_weight.to(i)
-            # Convert the weight to a map in which each voxel
-            # has the weight associated with the ground-truth label
-            # associated with this voxel in target.
-            at = class_weight[None, :, None]  # N => 1,N,1
-            at = at.expand((t.size(0), -1, t.size(2)))  # 1,N,1 => B,N,H*W
-            # Multiply the log proba by their weights.
-            ce = ce * at
-
-        # Compute the loss mini-batch.
-        # (1-p_t)^gamma * log(p_t) with reduced chance of overflow
-        p = F.logsigmoid(-i * (t * 2.0 - 1.0))
-        flat_loss: torch.Tensor = (p * self.gamma).exp() * ce
-
-        # Previously there was a mean over the last dimension, which did not
-        # return a compatible BCE loss. To maintain backwards compatible
-        # behavior we have a flag that performs this extra step, disable or
-        # parameterize if necessary. (Or justify why the mean should be there)
-        average_spatial_dims = True
+                raise ValueError("the value/values of the `weight` should be no less than 0.")
+            # apply class_weight to loss
+            class_weight = class_weight.to(loss)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            class_weight = class_weight.view(broadcast_dims)
+            loss = class_weight * loss
 
         if self.reduction == LossReduction.SUM.value:
+            # Previously there was a mean over the last dimension, which did not
+            # return a compatible BCE loss. To maintain backwards compatible
+            # behavior we have a flag that performs this extra step, disable or
+            # parameterize if necessary. (Or justify why the mean should be there)
+            average_spatial_dims = True
             if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.sum()
+                loss = loss.mean(dim=list(range(2, len(target.shape))))
+            loss = loss.sum()
         elif self.reduction == LossReduction.MEAN.value:
-            if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.mean()
+            loss = loss.mean()
         elif self.reduction == LossReduction.NONE.value:
-            spacetime_dims = input.shape[2:]
-            loss = flat_loss.reshape([b, n] + list(spacetime_dims))
+            pass
         else:
             raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
         return loss
+
+
+def softmax_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
+    s_j is the unnormalized score for class j.
+    """
+    input_ls = input.log_softmax(1)
+    loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
+
+    if alpha is not None:
+        # (1-alpha) for the background class and alpha for the other classes
+        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        broadcast_dims = [-1] + [1] * len(target.shape[2:])
+        alpha_fac = alpha_fac.view(broadcast_dims)
+        loss = alpha_fac * loss
+
+    return loss
+
+
+def sigmoid_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
+    """
+    # computing binary cross entropy with logits
+    # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
+    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
+    max_val = (-input).clamp(min=0)
+    loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+
+    # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-p if t==1; p if t==0 <=>
+    # pfac, that is, the term (1 - pt)
+    invprobs = F.logsigmoid(-input * (target * 2 - 1))  # reduced chance of overflow
+    # (pfac.log() * gamma).exp() <=>
+    # pfac.log().exp() ^ gamma <=>
+    # pfac ^ gamma
+    loss = (invprobs * gamma).exp() * loss
+
+    if alpha is not None:
+        # alpha if t==1; (1-alpha) if t==0
+        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        loss = alpha_factor * loss
+
+    return loss
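
As a sanity check on sigmoid_focal_loss above: with gamma=0 and alpha=None the modulating factor (1 - pt)**gamma equals 1, so the result should match binary cross entropy with logits element-wise. A small sketch, assuming the helper as defined above is in scope:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 3, 4, 4)
    labels = torch.randint(0, 2, (2, 3, 4, 4)).float()
    fl = sigmoid_focal_loss(logits, labels, gamma=0.0, alpha=None)
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    assert torch.allclose(fl, bce, atol=1e-6)  # focal loss degenerates to BCE when gamma == 0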
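
The view/broadcast pattern used for class_weight in forward (and for alpha_fac in softmax_focal_loss) reshapes a length-N vector so it scales the B x N x spatial loss map along the channel axis. A minimal sketch with made-up shapes:

    import torch

    loss_map = torch.rand(1, 2, 3, 3)                    # per-voxel loss, B x N x H x W
    class_weight = torch.tensor([0.2, 0.8])              # one weight per class (N = 2)
    broadcast_dims = [-1] + [1] * (loss_map.dim() - 2)   # -> [-1, 1, 1]
    weighted = class_weight.view(broadcast_dims) * loss_map  # broadcast over B and the spatial dims
    print(weighted.shape)                                # torch.Size([1, 2, 3, 3])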