Skip to content

Commit ed85478

Browse files
committed
Fix comment
1 parent 168750a commit ed85478

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

python/metrics_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ def loss_ownership_corr(
114114
assert mask_sum_hw.shape == (self.n,)
115115
magsq = torch.sum(torch.square(ownership_corr), dim=1, keepdim=True)
116116
# Smoothly map the magnitude from [0,infinity) -> [0,1) via tanh
117-
# So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|)) / sqrt(|x|)
117+
# So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|^2)) / sqrt(|x|^2)
118118
# But there's a division by 0 when |x| = 0, and also sqrt(0) has infinite gradient..
119119
# So to do this in a numerically stable way, we do this piecewise, using 3rd order taylor expansion
120-
# around 0.
120+
# around 0. Taylor expansion of tanh(sqrt(x)) / sqrt(x) is 1 - 1/3 x + 2/15 x^2 - 17/105 x^3.
121121
delta = 0.010
122122
sqrtmagsqboundedbelow = torch.sqrt(torch.clamp(magsq,min=0.008))
123123
magsqboundedabove = torch.clamp(magsq,max=0.012)
@@ -156,10 +156,10 @@ def loss_futurepos_corr(
156156
assert mask_sum_hw.shape == (self.n,)
157157
magsq = torch.sum(torch.square(futurepos_corr), dim=1, keepdim=True)
158158
# Smoothly map the magnitude from [0,infinity) -> [0,1) via tanh
159-
# So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|)) / sqrt(|x|)
159+
# So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|^2)) / sqrt(|x|^2)
160160
# But there's a division by 0 when |x| = 0, and also sqrt(0) has infinite gradient..
161161
# So to do this in a numerically stable way, we do this piecewise, using 3rd order taylor expansion
162-
# around 0.
162+
# around 0. Taylor expansion of tanh(sqrt(x)) / sqrt(x) is 1 - 1/3 x + 2/15 x^2 - 17/105 x^3.
163163
delta = 0.010
164164
sqrtmagsqboundedbelow = torch.sqrt(torch.clamp(magsq,min=0.008))
165165
magsqboundedabove = torch.clamp(magsq,max=0.012)

0 commit comments

Comments
 (0)