Skip to content

Commit 4ae6afa

Browse files
authored
Merge pull request #197 from basf/feat/splines
Fix A001 and A002
2 parents 7f49ec8 + 45e3a8a commit 4ae6afa

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

mambular/arch_utils/layer_utils/sparsemax.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ class SparsemaxFunction(Function):
3636
"""
3737

3838
@staticmethod
39-
def forward(ctx, input, dim=-1):
39+
def forward(ctx, input_, dim=-1):
4040
"""
4141
Forward pass of sparsemax: a normalizing, sparse transformation.
4242
4343
Parameters
4444
----------
45-
input : torch.Tensor
45+
input_ : torch.Tensor
4646
The input tensor on which sparsemax will be applied.
4747
dim : int, optional
4848
Dimension along which to apply sparsemax. Default is -1.
@@ -53,10 +53,10 @@ def forward(ctx, input, dim=-1):
5353
A tensor with the same shape as the input, with sparsemax applied.
5454
"""
5555
ctx.dim = dim
56-
max_val, _ = input.max(dim=dim, keepdim=True)
57-
input -= max_val # Numerical stability trick, as with softmax.
58-
tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
59-
output = torch.clamp(input - tau, min=0)
56+
max_val, _ = input_.max(dim=dim, keepdim=True)
57+
input_ -= max_val # Numerical stability trick, as with softmax.
58+
tau, supp_size = SparsemaxFunction._threshold_and_support(input_, dim=dim)
59+
output = torch.clamp(input_ - tau, min=0)
6060
ctx.save_for_backward(supp_size, output)
6161
return output
6262

@@ -86,13 +86,13 @@ def backward(ctx, grad_output): # type: ignore
8686
return grad_input, None
8787

8888
@staticmethod
89-
def _threshold_and_support(input, dim=-1):
89+
def _threshold_and_support(input_, dim=-1):
9090
"""
9191
Computes the threshold and support for sparsemax.
9292
9393
Parameters
9494
----------
95-
input : torch.Tensor
95+
input_ : torch.Tensor
9696
The input tensor on which to compute the threshold and support.
9797
dim : int, optional
9898
Dimension along which to compute the threshold and support. Default is -1.
@@ -103,14 +103,14 @@ def _threshold_and_support(input, dim=-1):
103103
- torch.Tensor : The threshold value for sparsemax.
104104
- torch.Tensor : The support size tensor.
105105
"""
106-
input_srt, _ = torch.sort(input, descending=True, dim=dim)
106+
input_srt, _ = torch.sort(input_, descending=True, dim=dim)
107107
input_cumsum = input_srt.cumsum(dim) - 1
108-
rhos = _make_ix_like(input, dim)
108+
rhos = _make_ix_like(input_, dim)
109109
support = rhos * input_srt > input_cumsum
110110

111111
support_size = support.sum(dim=dim).unsqueeze(dim)
112112
tau = input_cumsum.gather(dim, support_size - 1)
113-
tau /= support_size.to(input.dtype)
113+
tau /= support_size.to(input_.dtype)
114114
return tau, support_size
115115

116116

0 commit comments

Comments
 (0)