Commit da5d07b

Merge pull request #62 from kilianFatras/stochastic_OT
Debug and speedup SGD stochastic OT
2 parents 5180023 + 15f4b29 commit da5d07b

3 files changed: +89 additions, -152 deletions

README.md

Lines changed: 4 additions & 4 deletions

@@ -17,14 +17,14 @@ It provides the following solvers:
 * Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
 * Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
 * Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
-* Non regularized free support Wasserstein barycenters [20].
 * Bregman projections for Wasserstein barycenter [3] and unmixing [4].
 * Optimal transport for domain adaptation with group lasso regularization [5]
 * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
 * Linear OT [14] and Joint OT matrix and mapping estimation [8].
 * Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
 * Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
 * Stochastic Optimization for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* Non regularized free support Wasserstein barycenters [20].

 Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

@@ -164,7 +164,7 @@ The contributors to this library are:
 * [Stanislas Chambon](https://slasnista.github.io/)
 * [Antoine Rolet](https://arolet.github.io/)
 * Erwan Vautier (Gromov-Wasserstein)
-* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic optimization)
+* [Kilian Fatras](https://kilianfatras.github.io/)

 This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

@@ -223,8 +223,8 @@ You can also post bug reports and feature requests in Github issues. Make sure t

 [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).

-[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).
+[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016).

 [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)

-[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
+[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
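
For quick orientation, a minimal usage sketch of the stochastic solvers listed above (semi-dual problem [18], dual problem [19]). It assumes POT with the ot.stochastic module as modified by this commit; the point clouds and parameter values are illustrative only.

    import numpy as np
    import ot
    import ot.stochastic

    # Illustrative data: two small point clouds with uniform weights
    rng = np.random.RandomState(0)
    n_source, n_target = 7, 4
    a = ot.utils.unif(n_source)   # uniform source weights
    b = ot.utils.unif(n_target)   # uniform target weights
    M = ot.dist(rng.randn(n_source, 2), rng.randn(n_target, 2))
    reg = 1.0

    # Semi-dual problem [18], solved with SAG
    sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, "sag",
                                                    numItermax=10000)

    # Dual problem [19], solved with SGD on mini-batches of dual variables
    sgd_pi = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size=3,
                                               numItermax=10000, lr=0.1)

    # After enough iterations the marginals of the plan approach a and b
    print(sgd_pi.sum(1), sgd_pi.sum(0))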

ot/stochastic.py

Lines changed: 46 additions & 133 deletions

@@ -435,113 +435,40 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
 ##############################################################################


-def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
-                          batch_beta):
+def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
+                    batch_beta):
     '''
     Computes the partial gradient of F_\W_varepsilon

     Compute the partial gradient of the dual problem:

     ..math:
         \forall i in batch_alpha,
-            grad_alpha_i = 1 * batch_size -
-            sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
+            grad_alpha_i = alpha_i * batch_size/len(beta) -
+            sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
+            * a_i * b_j

-    where :
-        - M is the (ns,nt) metric cost matrix
-        - alpha, beta are dual variables in R^ixR^J
-        - reg is the regularization term
-        - batch_alpha and batch_beta are list of index
-
-    The algorithm used for solving the dual problem is the SGD algorithm
-    as proposed in [19]_ [alg.1]
-
-    Parameters
-    ----------
-
-    reg : float number,
-        Regularization term > 0
-    M : np.ndarray(ns, nt),
-        cost matrix
-    alpha : np.ndarray(ns,)
-        dual variable
-    beta : np.ndarray(nt,)
-        dual variable
-    batch_size : int number
-        size of the batch
-    batch_alpha : np.ndarray(bs,)
-        batch of index of alpha
-    batch_beta : np.ndarray(bs,)
-        batch of index of beta
-
-    Returns
-    -------
-
-    grad : np.ndarray(ns,)
-        partial grad F in alpha
-
-    Examples
-    --------
-
-    >>> n_source = 7
-    >>> n_target = 4
-    >>> reg = 1
-    >>> numItermax = 20000
-    >>> lr = 0.1
-    >>> batch_size = 3
-    >>> log = True
-    >>> a = ot.utils.unif(n_source)
-    >>> b = ot.utils.unif(n_target)
-    >>> rng = np.random.RandomState(0)
-    >>> X_source = rng.randn(n_source, 2)
-    >>> Y_target = rng.randn(n_target, 2)
-    >>> M = ot.dist(X_source, Y_target)
-    >>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
-                                                          batch_size,
-                                                          numItermax, lr, log)
-    >>> print(log['alpha'], log['beta'])
-    >>> print(sgd_dual_pi)
-
-    References
-    ----------
-
-    [Seguy et al., 2018] :
-        International Conference on Learning Representation (2018),
-        arXiv preprint arxiv:1711.02283.
-    '''
-
-    grad_alpha = np.zeros(batch_size)
-    grad_alpha[:] = batch_size
-    for j in batch_beta:
-        grad_alpha -= np.exp((alpha[batch_alpha] + beta[j] -
-                              M[batch_alpha, j]) / reg)
-    return grad_alpha
-
-
-def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
-                         batch_beta):
-    '''
-    Computes the partial gradient of F_\W_varepsilon
-
-    Compute the partial gradient of the dual problem:
-
-    ..math:
-        \forall j in batch_beta,
-            grad_beta_j = 1 * batch_size -
+        \forall j in batch_alpha,
+            grad_beta_j = beta_j * batch_size/len(alpha) -
             sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
-
+            * a_i * b_j
     where :
         - M is the (ns,nt) metric cost matrix
         - alpha, beta are dual variables in R^ixR^J
         - reg is the regularization term
-        - batch_alpha and batch_beta are list of index
+        - batch_alpha and batch_beta are lists of index
+        - a and b are source and target weights (sum to 1)
+

     The algorithm used for solving the dual problem is the SGD algorithm
     as proposed in [19]_ [alg.1]

     Parameters
     ----------
-
+    a : np.ndarray(ns,),
+        source measure
+    b : np.ndarray(nt,),
+        target measure
     M : np.ndarray(ns, nt),
         cost matrix
     reg : float number,
@@ -561,7 +488,7 @@ def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
     -------

     grad : np.ndarray(ns,)
-        partial grad F in beta
+        partial grad F

     Examples
     --------
@@ -591,19 +518,22 @@ def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
     [Seguy et al., 2018] :
         International Conference on Learning Representation (2018),
         arXiv preprint arxiv:1711.02283.
-
     '''

-    grad_beta = np.zeros(batch_size)
-    grad_beta[:] = batch_size
-    for i in batch_alpha:
-        grad_beta -= np.exp((alpha[i] +
-                             beta[batch_beta] - M[i, batch_beta]) / reg)
-    return grad_beta
+    G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
+                   M[batch_alpha, :][:, batch_beta]) / reg) *
+           a[batch_alpha, None] * b[None, batch_beta])
+    grad_beta = np.zeros(np.shape(M)[1])
+    grad_alpha = np.zeros(np.shape(M)[0])
+    grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] +
+                             G.sum(0))
+    grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) /
+                               np.shape(M)[1] + G.sum(1))
+
+    return grad_alpha, grad_beta


-def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
-                                alternate=True):
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
     '''
     Compute the sgd algorithm to solve the regularized discrete measures
     optimal transport dual problem
@@ -623,7 +553,10 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,

     Parameters
     ----------
-
+    a : np.ndarray(ns,),
+        source measure
+    b : np.ndarray(nt,),
+        target measure
     M : np.ndarray(ns, nt),
         cost matrix
     reg : float number,
@@ -634,8 +567,6 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
         number of iteration
     lr : float number
         learning rate
-    alternate : bool, optional
-        alternating algorithm

     Returns
     -------
@@ -662,8 +593,8 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
     >>> Y_target = rng.randn(n_target, 2)
     >>> M = ot.dist(X_source, Y_target)
     >>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
-                                                          batch_size,
-                                                          numItermax, lr, log)
+                                                           batch_size,
+                                                           numItermax, lr, log)
     >>> print(log['alpha'], log['beta'])
     >>> print(sgd_dual_pi)
@@ -677,35 +608,17 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,

     n_source = np.shape(M)[0]
     n_target = np.shape(M)[1]
-    cur_alpha = np.random.randn(n_source)
-    cur_beta = np.random.randn(n_target)
-    if alternate:
-        for cur_iter in range(numItermax):
-            k = np.sqrt(cur_iter + 1)
-            batch_alpha = np.random.choice(n_source, batch_size, replace=False)
-            batch_beta = np.random.choice(n_target, batch_size, replace=False)
-            grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
-                                                 batch_size, batch_alpha,
-                                                 batch_beta)
-            cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
-            grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
-                                               batch_size, batch_alpha,
-                                               batch_beta)
-            cur_beta[batch_beta] += (lr / k) * grad_F_beta
-
-    else:
-        for cur_iter in range(numItermax):
-            k = np.sqrt(cur_iter + 1)
-            batch_alpha = np.random.choice(n_source, batch_size, replace=False)
-            batch_beta = np.random.choice(n_target, batch_size, replace=False)
-            grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
-                                                 batch_size, batch_alpha,
-                                                 batch_beta)
-            grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
-                                               batch_size, batch_alpha,
-                                               batch_beta)
-            cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
-            cur_beta[batch_beta] += (lr / k) * grad_F_beta
+    cur_alpha = np.zeros(n_source)
+    cur_beta = np.zeros(n_target)
+    for cur_iter in range(numItermax):
+        k = np.sqrt(cur_iter + 1)
+        batch_alpha = np.random.choice(n_source, batch_size, replace=False)
+        batch_beta = np.random.choice(n_target, batch_size, replace=False)
+        update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
+                                                    cur_beta, batch_size,
+                                                    batch_alpha, batch_beta)
+        cur_alpha += (lr / k) * update_alpha
+        cur_beta += (lr / k) * update_beta

     return cur_alpha, cur_beta

@@ -787,7 +700,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
         arXiv preprint arxiv:1711.02283.
     '''

-    opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, batch_size,
+    opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
                                                       numItermax, lr)
     pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
           a[:, None] * b[None, :])
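
An editorial aside on batch_grad_dual introduced above (a reference sketch, not code from this commit): the vectorized G holds, for i in batch_alpha and j in batch_beta, the terms -exp((alpha_i + beta_j - M_{i,j})/reg) * a_i * b_j, so the updates amount to grad_alpha_i = a_i * len(batch_beta)/nt + sum_j G_{i,j} and grad_beta_j = b_j * len(batch_alpha)/ns + sum_i G_{i,j}. A loop-based sketch of the same quantity (the name ref_batch_grad_dual and the toy data are illustrative only):

    import numpy as np

    def ref_batch_grad_dual(a, b, M, reg, alpha, beta, batch_alpha, batch_beta):
        # Loop version of the broadcasting in batch_grad_dual above
        ns, nt = M.shape
        grad_alpha = np.zeros(ns)
        grad_beta = np.zeros(nt)
        for i in batch_alpha:
            total = 0.0
            for j in batch_beta:
                total += np.exp((alpha[i] + beta[j] - M[i, j]) / reg) * a[i] * b[j]
            # same scaling as the vectorized code: a_i * |batch_beta| / nt - sum_j ...
            grad_alpha[i] = a[i] * len(batch_beta) / nt - total
        for j in batch_beta:
            total = 0.0
            for i in batch_alpha:
                total += np.exp((alpha[i] + beta[j] - M[i, j]) / reg) * a[i] * b[j]
            grad_beta[j] = b[j] * len(batch_alpha) / ns - total
        return grad_alpha, grad_beta

    # Toy check on a small problem with uniform weights
    rng = np.random.RandomState(0)
    ns, nt, reg = 6, 5, 1.0
    a, b = np.full(ns, 1. / ns), np.full(nt, 1. / nt)
    M = rng.rand(ns, nt)
    alpha, beta = np.zeros(ns), np.zeros(nt)
    batch_alpha = rng.choice(ns, 3, replace=False)
    batch_beta = rng.choice(nt, 3, replace=False)
    g_a, g_b = ref_batch_grad_dual(a, b, M, reg, alpha, beta, batch_alpha, batch_beta)
    # Both are ascent directions on the dual; sgd_entropic_regularization applies
    # them as cur_alpha += (lr / k) * g_a and cur_beta += (lr / k) * g_b.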

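For reference, the final hunk of ot/stochastic.py above recovers the primal transport plan from the learned dual variables through the usual first-order relation for entropic OT (same notation as the docstrings):

    pi_{i,j} = a_i * b_j * exp((alpha_i + beta_j - M_{i,j}) / reg)

As alpha and beta converge, the row and column sums of this pi approach a and b, which is what the updated tests below verify by comparing against the Sinkhorn plan.
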
test/test_stochastic.py

Lines changed: 39 additions & 15 deletions

@@ -97,7 +97,6 @@ def test_sag_asgd_sinkhorn():

     x = rng.randn(n, 2)
     u = ot.utils.unif(n)
-    zero = np.zeros(n)
     M = ot.dist(x, x)

     G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
@@ -108,13 +107,13 @@ def test_sag_asgd_sinkhorn():

     # check constratints
     np.testing.assert_allclose(
-        zero, (G_sag - G_sinkhorn).sum(1), atol=1e-03)  # cf convergence sag
+        G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_sag - G_sinkhorn).sum(0), atol=1e-03)  # cf convergence sag
+        G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_asgd - G_sinkhorn).sum(1), atol=1e-03)  # cf convergence asgd
+        G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_asgd - G_sinkhorn).sum(0), atol=1e-03)  # cf convergence asgd
+        G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
         G_sag, G_sinkhorn, atol=1e-03)  # cf convergence sag
     np.testing.assert_allclose(
@@ -137,8 +136,8 @@ def test_stochastic_dual_sgd():
     # test sgd
     n = 10
     reg = 1
-    numItermax = 300000
-    batch_size = 8
+    numItermax = 15000
+    batch_size = 10
     rng = np.random.RandomState(0)

     x = rng.randn(n, 2)
@@ -151,9 +150,9 @@ def test_stochastic_dual_sgd():

     # check constratints
     np.testing.assert_allclose(
-        u, G.sum(1), atol=1e-02)  # cf convergence sgd
+        u, G.sum(1), atol=1e-03)  # cf convergence sgd
     np.testing.assert_allclose(
-        u, G.sum(0), atol=1e-02)  # cf convergence sgd
+        u, G.sum(0), atol=1e-03)  # cf convergence sgd


 #############################################################################
@@ -168,13 +167,13 @@ def test_dual_sgd_sinkhorn():
     # test all dual algorithms
     n = 10
     reg = 1
-    nb_iter = 300000
-    batch_size = 8
+    nb_iter = 150000
+    batch_size = 10
     rng = np.random.RandomState(0)

+    # Test uniform
     x = rng.randn(n, 2)
     u = ot.utils.unif(n)
-    zero = np.zeros(n)
     M = ot.dist(x, x)

     G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
@@ -184,8 +183,33 @@ def test_dual_sgd_sinkhorn():

     # check constratints
     np.testing.assert_allclose(
-        zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-02)  # cf convergence sgd
+        G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+    np.testing.assert_allclose(
+        G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+    np.testing.assert_allclose(
+        G_sgd, G_sinkhorn, atol=1e-03)  # cf convergence sgd
+
+    # Test gaussian
+    n = 30
+    reg = 1
+    batch_size = 30
+
+    a = ot.datasets.make_1D_gauss(n, 15, 5)  # m= mean, s= std
+    b = ot.datasets.make_1D_gauss(n, 15, 5)
+    X_source = np.arange(n, dtype=np.float64)
+    Y_target = np.arange(n, dtype=np.float64)
+    M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1)))
+    M /= M.max()
+
+    G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size,
+                                              numItermax=nb_iter)
+
+    G_sinkhorn = ot.sinkhorn(a, b, M, reg)
+
+    # check constratints
+    np.testing.assert_allclose(
+        G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
     np.testing.assert_allclose(
-        zero, (G_sgd - G_sinkhorn).sum(0), atol=1e-02)  # cf convergence sgd
+        G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
     np.testing.assert_allclose(
-        G_sgd, G_sinkhorn, atol=1e-02)  # cf convergence sgd
+        G_sgd, G_sinkhorn, atol=1e-03)  # cf convergence sgd
