@@ -114,10 +114,10 @@ def loss_ownership_corr(
114114 assert mask_sum_hw .shape == (self .n ,)
115115 magsq = torch .sum (torch .square (ownership_corr ), dim = 1 , keepdim = True )
116116 # Smoothly map the magnitude from [0,infinity) -> [0,1) via tanh
117- # So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|)) / sqrt(|x|)
117+ # So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|^2 )) / sqrt(|x|^2 )
118118 # But there's a division by 0 when |x| = 0, and also sqrt(0) has infinite gradient..
119119 # So to do this in a numerically stable way, we do this piecewise, using 3rd order taylor expansion
120- # around 0.
120+ # around 0. Taylor expansion of tanh(sqrt(x)) / sqrt(x) is 1 - 1/3 x + 2/15 x^2 - 17/315 x^3.
121121 delta = 0.010
122122 sqrtmagsqboundedbelow = torch .sqrt (torch .clamp (magsq ,min = 0.008 ))
123123 magsqboundedabove = torch .clamp (magsq ,max = 0.012 )
@@ -156,10 +156,10 @@ def loss_futurepos_corr(
156156 assert mask_sum_hw .shape == (self .n ,)
157157 magsq = torch .sum (torch .square (futurepos_corr ), dim = 1 , keepdim = True )
158158 # Smoothly map the magnitude from [0,infinity) -> [0,1) via tanh
159- # So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|)) / sqrt(|x|)
159+ # So we need to multiply vectors x by a factor of tanh(|x|) / |x| = tanh(sqrt(|x|^2 )) / sqrt(|x|^2 )
160160 # But there's a division by 0 when |x| = 0, and also sqrt(0) has infinite gradient..
161161 # So to do this in a numerically stable way, we do this piecewise, using 3rd order taylor expansion
162- # around 0.
162+ # around 0. Taylor expansion of tanh(sqrt(x)) / sqrt(x) is 1 - 1/3 x + 2/15 x^2 - 17/315 x^3.
163163 delta = 0.010
164164 sqrtmagsqboundedbelow = torch .sqrt (torch .clamp (magsq ,min = 0.008 ))
165165 magsqboundedabove = torch .clamp (magsq ,max = 0.012 )
0 commit comments