Commit 9d6ccce

add softmax version to focal_loss (#6544)
Fixes #6510.

### Description

Add softmax version to Focal loss.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [ ] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Qingpeng Li <[email protected]>
1 parent 4bd93c0 commit 9d6ccce

2 files changed: +293 / -131 lines changed


monai/losses/focal_loss.py

Lines changed: 106 additions & 54 deletions
@@ -28,10 +28,10 @@ class FocalLoss(_Loss):
     FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
     high confidence correct predictions.
 
-    Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in:
+    Reimplementation of the Focal Loss described in:
 
-        - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
-        - "AnatomyNet: Deep learning for fast and fully automated wholevolume segmentation of head and neck anatomy",
+        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
+        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
           Zhu et al., Medical Physics 2018
 
     Example:
@@ -70,19 +70,23 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
+        alpha: float | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
     ) -> None:
         """
         Args:
-            include_background: if False, channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            gamma: value of the exponent gamma in the definition of the Focal loss.
+            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
+                If False, `alpha` is invalid when using softmax.
+            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
+                The value should be in [0, 1]. Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
-                This corresponds to the weights `\alpha` in [1].
                 The input can be a single value (same weight for all classes), a sequence of values (the length
-                of the sequence should be the same as the number of classes, if not ``include_background``, the
-                number should not include class 0).
+                of the sequence should be the same as the number of classes. If not ``include_background``,
+                the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
                 Specifies the reduction to apply to the output. Defaults to ``"mean"``.
@@ -91,6 +95,9 @@ def __init__(
                 - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                 - ``"sum"``: the output will be summed.
 
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
+
         Example:
             >>> import torch
             >>> from monai.losses import FocalLoss
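
(Not part of the diff.) A minimal usage sketch, assuming the new constructor arguments shown above; the shapes, the 3-class setup, and the parameter values are illustrative only:

```python
import torch
from monai.losses import FocalLoss

# Illustrative shapes: batch of 2, 3 classes, a 32x32 spatial grid.
logits = torch.randn(2, 3, 32, 32)            # raw scores, B x N x H x W
labels = torch.randint(0, 3, (2, 1, 32, 32))  # class-index labels, B x 1 x H x W

# to_onehot_y expands the label map to one channel per class before the loss is computed;
# use_softmax=True selects the new softmax path, alpha balances background vs. foreground.
loss_fn = FocalLoss(to_onehot_y=True, gamma=2.0, alpha=0.25, use_softmax=True)
print(loss_fn(logits, labels))  # scalar, since reduction defaults to "mean"
```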
@@ -103,14 +110,16 @@ def __init__(
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.weight: Sequence[float] | float | int | torch.Tensor | None = weight
+        self.alpha = alpha
+        self.weight = weight
+        self.use_softmax = use_softmax
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: the shape should be BNH[WD], where N is the number of classes.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
+                a sigmoid/softmax in the forward function.
             target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
 
         Raises:
@@ -141,63 +150,106 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if target.shape != input.shape:
             raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
 
-        i = input
-        t = target
-
-        # Change the shape of input and target to B x N x num_voxels.
-        b, n = t.shape[:2]
-        i = i.reshape(b, n, -1)
-        t = t.reshape(b, n, -1)
-
-        # computing binary cross entropy with logits
-        # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
-        max_val = (-i).clamp(min=0)
-        ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
+        loss: Optional[torch.Tensor] = None
+        input = input.float()
+        target = target.float()
+        if self.use_softmax:
+            if not self.include_background and self.alpha is not None:
+                self.alpha = None
+                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+        else:
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
         if self.weight is not None:
+            # make sure the lengths of weights are equal to the number of classes
             class_weight: Optional[torch.Tensor] = None
+            num_of_classes = target.shape[1]
             if isinstance(self.weight, (float, int)):
-                class_weight = torch.as_tensor([self.weight] * i.size(1))
+                class_weight = torch.as_tensor([self.weight] * num_of_classes)
             else:
                 class_weight = torch.as_tensor(self.weight)
-                if class_weight.size(0) != i.size(1):
+                if class_weight.shape[0] != num_of_classes:
                     raise ValueError(
-                        "the length of the weight sequence should be the same as the number of classes. "
-                        + "If `include_background=False`, the number should not include class 0."
+                        """the length of the `weight` sequence should be the same as the number of classes.
+                        If `include_background=False`, the weight should not include
+                        the background category class 0."""
                     )
             if class_weight.min() < 0:
-                raise ValueError("the value/values of weights should be no less than 0.")
-            class_weight = class_weight.to(i)
-            # Convert the weight to a map in which each voxel
-            # has the weight associated with the ground-truth label
-            # associated with this voxel in target.
-            at = class_weight[None, :, None]  # N => 1,N,1
-            at = at.expand((t.size(0), -1, t.size(2)))  # 1,N,1 => B,N,H*W
-            # Multiply the log proba by their weights.
-            ce = ce * at
-
-        # Compute the loss mini-batch.
-        # (1-p_t)^gamma * log(p_t) with reduced chance of overflow
-        p = F.logsigmoid(-i * (t * 2.0 - 1.0))
-        flat_loss: torch.Tensor = (p * self.gamma).exp() * ce
-
-        # Previously there was a mean over the last dimension, which did not
-        # return a compatible BCE loss. To maintain backwards compatible
-        # behavior we have a flag that performs this extra step, disable or
-        # parameterize if necessary. (Or justify why the mean should be there)
-        average_spatial_dims = True
+                raise ValueError("the value/values of the `weight` should be no less than 0.")
+            # apply class_weight to loss
+            class_weight = class_weight.to(loss)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            class_weight = class_weight.view(broadcast_dims)
+            loss = class_weight * loss
 
         if self.reduction == LossReduction.SUM.value:
+            # Previously there was a mean over the last dimension, which did not
+            # return a compatible BCE loss. To maintain backwards compatible
+            # behavior we have a flag that performs this extra step, disable or
+            # parameterize if necessary. (Or justify why the mean should be there)
+            average_spatial_dims = True
             if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.sum()
+                loss = loss.mean(dim=list(range(2, len(target.shape))))
+            loss = loss.sum()
         elif self.reduction == LossReduction.MEAN.value:
-            if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.mean()
+            loss = loss.mean()
         elif self.reduction == LossReduction.NONE.value:
-            spacetime_dims = input.shape[2:]
-            loss = flat_loss.reshape([b, n] + list(spacetime_dims))
+            pass
         else:
             raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
         return loss
+
+
+def softmax_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
+    s_j is the unnormalized score for class j.
+    """
+    input_ls = input.log_softmax(1)
+    loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
+
+    if alpha is not None:
+        # (1-alpha) for the background class and alpha for the other classes
+        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        broadcast_dims = [-1] + [1] * len(target.shape[2:])
+        alpha_fac = alpha_fac.view(broadcast_dims)
+        loss = alpha_fac * loss
+
+    return loss
+
+
+def sigmoid_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
+    """
+    # computing binary cross entropy with logits
+    # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
+    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
+    max_val = (-input).clamp(min=0)
+    loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+
+    # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-p if t==1; p if t==0 <=>
+    # pfac, that is, the term (1 - pt)
+    invprobs = F.logsigmoid(-input * (target * 2 - 1))  # reduced chance of overflow
+    # (pfac.log() * gamma).exp() <=>
+    # pfac.log().exp() ^ gamma <=>
+    # pfac ^ gamma
+    loss = (invprobs * gamma).exp() * loss
+
+    if alpha is not None:
+        # alpha if t==1; (1-alpha) if t==0
+        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        loss = alpha_factor * loss
+
+    return loss
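
(Not part of the diff.) A small sketch of calling the new `softmax_focal_loss` helper directly, assuming it is imported from `monai.losses.focal_loss` as added above; the tensor values are made up for illustration:

```python
import torch
from monai.losses.focal_loss import softmax_focal_loss

# One voxel, three classes; the ground truth is class 0 (the background channel).
scores = torch.tensor([[[2.0], [0.5], [-1.0]]])  # shape (1, 3, 1): B x N x spatial
onehot = torch.tensor([[[1.0], [0.0], [0.0]]])

# Returns a per-channel, per-voxel loss map: only the true-class channel is non-zero here,
# and with alpha=0.25 the background channel is scaled by (1 - alpha) = 0.75.
print(softmax_focal_loss(scores, onehot, gamma=2.0, alpha=0.25))
```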

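(Not part of the diff.) A sanity-check sketch for the `sigmoid_focal_loss` helper, following the comment in the added code that the base term equals element-wise BCE with logits; with `gamma=0` and `alpha=None` the whole expression should reduce to exactly that:

```python
import torch
import torch.nn.functional as F
from monai.losses.focal_loss import sigmoid_focal_loss

x = torch.randn(2, 3, 8, 8)
y = torch.randint(0, 2, (2, 3, 8, 8)).float()

# With gamma=0 the modulating factor (1 - pt)**gamma is 1, so the focal term vanishes
# and the result should match plain element-wise binary cross entropy with logits.
bce = F.binary_cross_entropy_with_logits(x, y, reduction="none")
assert torch.allclose(sigmoid_focal_loss(x, y, gamma=0.0), bce, atol=1e-6)
```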