Skip to content

Commit 4367a34

Browse files
authored
Merge pull request #65 from kilianFatras/stochastic_OT
better implementation on stocjastic gradient updates
2 parents da5d07b + 2b8b180 commit 4367a34

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ot/stochastic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,8 +617,8 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
617617
update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
618618
cur_beta, batch_size,
619619
batch_alpha, batch_beta)
620-
cur_alpha += (lr / k) * update_alpha
621-
cur_beta += (lr / k) * update_beta
620+
cur_alpha[batch_alpha] += (lr / k) * update_alpha[batch_alpha]
621+
cur_beta[batch_beta] += (lr / k) * update_beta[batch_beta]
622622

623623
return cur_alpha, cur_beta
624624

0 commit comments

Comments
 (0)