import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """
    Focal loss, FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t),
    which down-weights already well-classified examples (large p_t).
    reference: https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
    """
    def __init__(self, gamma=2, alpha=0.25, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # A scalar alpha becomes the binary class-weight pair [alpha, 1 - alpha];
        # a list is taken as one weight per class.
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)  # log-probability of the true class
        logpt = logpt.view(-1)
        pt = logpt.detach().exp()  # detached, so (1 - pt)**gamma acts as a fixed weight

        if self.alpha is not None:
            if self.alpha.type() != input.type():
                self.alpha = self.alpha.type_as(input)
            at = self.alpha.gather(0, target.view(-1))  # per-example class weight alpha_t
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()
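

# A minimal usage sketch in a training loop (model, optimizer, and batch names
# are illustrative assumptions, not part of the referenced repo). Gradients flow
# through logpt only, since pt is detached inside forward().
def train_step(model, optimizer, criterion, batch_inputs, batch_targets):
    optimizer.zero_grad()
    loss = criterion(model(batch_inputs), batch_targets)
    loss.backward()
    optimizer.step()
    return loss.item()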


if __name__ == '__main__':
    torch.manual_seed(1)
    inputs = torch.randn(10, 2)                # logits for 10 samples, 2 classes
    targets = torch.LongTensor(10).random_(2)  # random binary labels
    loss = FocalLoss()(inputs, targets)
    print(loss)
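    # Sanity check (not in the original file): with gamma=0 and alpha=None the
    # focal loss reduces to plain cross-entropy, so these two values should match.
    print(F.cross_entropy(inputs, targets).item(),
          FocalLoss(gamma=0, alpha=None)(inputs, targets).item())

    # Sketch of the dim() > 2 branch with segmentation-style N,C,H,W logits
    # (shapes and the 3-class alpha list are illustrative assumptions).
    seg_inputs = torch.randn(4, 3, 8, 8)                # 4 images, 3 classes, 8x8
    seg_targets = torch.LongTensor(4, 8, 8).random_(3)  # one class index per pixel
    print(FocalLoss(alpha=[0.25, 0.5, 0.25])(seg_inputs, seg_targets))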