@@ -36,13 +36,13 @@ class SparsemaxFunction(Function):
36
36
"""
37
37
38
38
@staticmethod
def forward(ctx, input_, dim=-1):
    """
    Forward pass of sparsemax: a normalizing, sparse transformation.

    Parameters
    ----------
    ctx : torch.autograd.function.FunctionCtx
        Autograd context used to stash tensors for the backward pass.
    input_ : torch.Tensor
        The input tensor on which sparsemax will be applied.
    dim : int, optional
        Dimension along which to apply sparsemax. Default is -1.

    Returns
    -------
    torch.Tensor
        A tensor with the same shape as the input, with sparsemax applied.
    """
    ctx.dim = dim
    # Numerical stability trick, as with softmax: sparsemax is invariant
    # to a constant shift along `dim`, so subtract the max first.
    max_val, _ = input_.max(dim=dim, keepdim=True)
    # Fix: shift out-of-place. The original `input_ -= max_val` mutated the
    # caller's tensor in place, silently changing the argument the caller
    # still holds (and in-place edits of autograd inputs are unsafe).
    shifted = input_ - max_val
    tau, supp_size = SparsemaxFunction._threshold_and_support(shifted, dim=dim)
    output = torch.clamp(shifted - tau, min=0)
    # Saved for the backward pass: support sizes and the sparse output.
    ctx.save_for_backward(supp_size, output)
    return output
62
62
@@ -86,13 +86,13 @@ def backward(ctx, grad_output): # type: ignore
86
86
return grad_input , None
87
87
88
88
@staticmethod
def _threshold_and_support(input_, dim=-1):
    """
    Computes the sparsemax threshold tau and the support size.

    Parameters
    ----------
    input_ : torch.Tensor
        The input tensor on which to compute the threshold and support.
    dim : int, optional
        Dimension along which to compute the threshold and support. Default is -1.

    Returns
    -------
    tuple of torch.Tensor
        - torch.Tensor : The threshold value for sparsemax.
        - torch.Tensor : The support size tensor.
    """
    # Sort descending so the support is a prefix of each slice.
    sorted_vals, _ = torch.sort(input_, descending=True, dim=dim)
    cumulative = sorted_vals.cumsum(dim) - 1
    # 1-based rank indices with the same layout as the sorted values.
    ranks = _make_ix_like(input_, dim)
    # An entry is in the support while rank * value exceeds the shifted cumsum.
    in_support = ranks * sorted_vals > cumulative
    support_size = in_support.sum(dim=dim).unsqueeze(dim)
    # tau is the shifted cumulative sum at the last supported position,
    # averaged over the support.
    tau = cumulative.gather(dim, support_size - 1)
    tau = tau / support_size.to(input_.dtype)
    return tau, support_size
115
115
116
116
0 commit comments