Skip to content

Commit 6847124

Browse files
committed
focal loss implementment
1 parent 92453c3 commit 6847124

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

utils/FocalLoss.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
5+
6+
class FocalLoss(nn.Module):
7+
"""
8+
reference: https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
9+
"""
10+
def __init__(self, gamma=2, alpha=0.25, size_average=True):
11+
super(FocalLoss, self).__init__()
12+
self.gamma = gamma
13+
self.alpha = alpha
14+
if isinstance(alpha,(float, int)): self.alpha = torch.Tensor([alpha,1-alpha])
15+
if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
16+
self.size_average = size_average
17+
18+
def forward(self, input, target):
19+
if input.dim()>2:
20+
input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
21+
input = input.transpose(1,2) # N,C,H*W => N,H*W,C
22+
input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C
23+
target = target.view(-1,1)
24+
logpt = F.log_softmax(input, dim=1)
25+
logpt = logpt.gather(1,target)
26+
logpt = logpt.view(-1)
27+
pt = Variable(logpt.data.exp())
28+
29+
if self.alpha is not None:
30+
if self.alpha.type()!=input.data.type():
31+
self.alpha = self.alpha.type_as(input.data)
32+
at = self.alpha.gather(0,target.data.view(-1))
33+
logpt = logpt * Variable(at)
34+
35+
loss = -1 * (1-pt)**self.gamma * logpt
36+
if self.size_average: return loss.mean()
37+
else: return loss.sum()
38+
39+
if __name__ == '__main__':
40+
torch.manual_seed(1)
41+
inputs = Variable(torch.randn((10, 2)))
42+
targets = Variable(torch.LongTensor(10).random_(2))
43+
loss = FocalLoss()(inputs, targets)
44+
print(loss)

0 commit comments

Comments
 (0)