Commit 08b859d

Merge branch 'master' into new_gpu
2 parents fa7f3dd + 4b05176 commit 08b859d

5 files changed: +229 -6 lines changed

.github/ISSUE_TEMPLATE/bug_report.md

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
1+
---
2+
name: Bug report
3+
about: Create a report to help us improve POT
4+
5+
---
6+
7+
**Describe the bug**
8+
A clear and concise description of what the bug is.
9+
10+
**To Reproduce**
11+
Steps to reproduce the behavior:
12+
1. ...
13+
2.
14+
15+
**Expected behavior**
16+
A clear and concise description of what you expected to happen.
17+
18+
**Screenshots**
19+
If applicable, add screenshots to help explain your problem.
20+
21+
**Desktop (please complete the following information):**
22+
- OS: [e.g. MacOSX, Windows, Ubuntu]
23+
- Python version: [e.g. 2.7, 3.6]
24+
- How was POT installed: [e.g. source, pip, conda]
25+
26+
Output of the following code snippet:
27+
```python
28+
import platform; print(platform.platform())
29+
import sys; print("Python", sys.version)
30+
import numpy; print("NumPy", numpy.__version__)
31+
import scipy; print("SciPy", scipy.__version__)
32+
import ot; print("POT", ot.__version__)
33+
```
34+
35+
**Additional context**
36+
Add any other context about the problem here.

README.md

Lines changed: 5 additions & 2 deletions
@@ -14,10 +14,10 @@ This open source Python library provide several solvers for optimization problem
1414
It provides the following solvers:
1515

1616
* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
17-
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cupy).
17+
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] and greedy Sinkhorn [22] with optional GPU implementation (requires cupy).
1818
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
1919
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
20-
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
20+
* Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4].
2121
* Optimal transport for domain adaptation with group lasso regularization [5]
2222
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
2323
* Linear OT [14] and Joint OT matrix and mapping estimation [8].
@@ -161,6 +161,7 @@ The contributors to this library are:
161161
* [Antoine Rolet](https://arolet.github.io/)
162162
* Erwan Vautier (Gromov-Wasserstein)
163163
* [Kilian Fatras](https://kilianfatras.github.io/)
164+
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
164165

165166
This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
166167

@@ -226,3 +227,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
226227
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference on Machine Learning
227228

228229
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
230+
231+
[22] J. Altschuler, J. Weed, P. Rigollet (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 30

docs/source/readme.rst

Lines changed: 35 additions & 3 deletions
@@ -13,13 +13,14 @@ It provides the following solvers:
1313
- OT Network Flow solver for the linear program/ Earth Movers Distance
1414
[1].
1515
- Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2]
16-
and stabilized version [9][10] with optional GPU implementation
17-
(requires cudamat).
16+
and stabilized version [9][10] and greedy Sinkhorn [22] with optional
17+
GPU implementation (requires cudamat).
1818
- Smooth optimal transport solvers (dual and semi-dual) for KL and
1919
squared L2 regularizations [17].
2020
- Non regularized Wasserstein barycenters [16] with LP solver (only
2121
small scale).
22-
- Bregman projections for Wasserstein barycenter [3] and unmixing [4].
22+
- Bregman projections for Wasserstein barycenter [3], convolutional
23+
barycenter [21] and unmixing [4].
2324
- Optimal transport for domain adaptation with group lasso
2425
regularization [5]
2526
- Conditional gradient [6] and Generalized conditional gradient for
@@ -29,6 +30,9 @@ It provides the following solvers:
2930
pymanopt).
3031
- Gromov-Wasserstein distances and barycenters ([13] and regularized
3132
[12])
33+
- Stochastic Optimization for Large-scale Optimal Transport (semi-dual
34+
problem [18] and dual problem [19])
35+
- Non regularized free support Wasserstein barycenters [20].
3236

3337
Some demonstrations (both in Python and Jupyter Notebook format) are
3438
available in the examples folder.
@@ -219,6 +223,9 @@ The contributors to this library are:
219223
- `Stanislas Chambon <https://slasnista.github.io/>`__
220224
- `Antoine Rolet <https://arolet.github.io/>`__
221225
- Erwan Vautier (Gromov-Wasserstein)
226+
- `Kilian Fatras <https://kilianfatras.github.io/>`__
227+
- `Alain
228+
Rakotomamonjy <https://sites.google.com/site/alainrakotomamonjy/home>`__
222229

223230
This toolbox benefits a lot from open source research and we would like
224231
to thank the following persons for providing some code (in various
@@ -334,6 +341,31 @@ Optimal Transport <https://arxiv.org/abs/1710.06276>`__. Proceedings of
334341
the Twenty-First International Conference on Artificial Intelligence and
335342
Statistics (AISTATS).
336343

344+
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) `Stochastic
345+
Optimization for Large-scale Optimal
346+
Transport <https://arxiv.org/abs/1605.08527>`__. Advances in Neural
347+
Information Processing Systems (2016).
348+
349+
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet,
350+
A. & Blondel, M. `Large-scale Optimal Transport and Mapping
351+
Estimation <https://arxiv.org/pdf/1711.02283.pdf>`__. International
352+
Conference on Learning Representations (2018)
353+
354+
[20] Cuturi, M. and Doucet, A. (2014) `Fast Computation of Wasserstein
355+
Barycenters <http://proceedings.mlr.press/v32/cuturi14.html>`__.
356+
International Conference on Machine Learning
357+
358+
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
359+
Nguyen, A. & Guibas, L. (2015). `Convolutional wasserstein distances:
360+
Efficient optimal transportation on geometric
361+
domains <https://dl.acm.org/citation.cfm?id=2766963>`__. ACM
362+
Transactions on Graphics (TOG), 34(4), 66.
363+
364+
[22] J. Altschuler, J. Weed, P. Rigollet (2017) `Near-linear time
365+
approximation algorithms for optimal transport via Sinkhorn
366+
iteration <https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf>`__,
367+
Advances in Neural Information Processing Systems (NIPS) 30
368+
337369
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
338370
:target: https://badge.fury.io/py/POT
339371
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg

ot/bregman.py

Lines changed: 150 additions & 1 deletion
@@ -47,7 +47,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
4747
reg : float
4848
Regularization term >0
4949
method : str
50-
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
50+
method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
5151
'sinkhorn_epsilon_scaling', see those functions for specific parameters
5252
numItermax : int, optional
5353
Max number of iterations
@@ -103,6 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
103103
def sink():
104104
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
105105
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
106+
elif method.lower() == 'greenkhorn':
107+
def sink():
108+
return greenkhorn(a, b, M, reg, numItermax=numItermax,
109+
stopThr=stopThr, verbose=verbose, log=log)
106110
elif method.lower() == 'sinkhorn_stabilized':
107111
def sink():
108112
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
@@ -197,13 +201,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
197201
198202
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
199203
204+
.. [22] Altschuler J., Weed J., Rigollet P.: Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 30, 2017
205+
200206
201207
202208
See Also
203209
--------
204210
ot.lp.emd : Unregularized OT
205211
ot.optim.cg : General regularized OT
206212
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
213+
ot.bregman.greenkhorn : Greenkhorn [22]
207214
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
208215
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
209216
@@ -410,6 +417,148 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
410417
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
411418

412419

420+
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
421+
"""
422+
Solve the entropic regularization optimal transport problem and return the OT matrix
423+
424+
The algorithm used is based on the paper
425+
426+
Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
427+
by Jason Altschuler, Jonathan Weed, Philippe Rigollet
428+
appeared at NIPS 2017
429+
430+
which is a greedy version of the Sinkhorn-Knopp algorithm [2].
431+
432+
The function solves the following optimization problem:
433+
434+
.. math::
435+
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
436+
437+
s.t. \gamma 1 = a
438+
439+
\gamma^T 1= b
440+
441+
\gamma\geq 0
442+
where :
443+
444+
- M is the (ns,nt) metric cost matrix
445+
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
446+
- a and b are source and target weights (sum to 1)
447+
448+
449+
450+
Parameters
451+
----------
452+
a : np.ndarray (ns,)
453+
samples weights in the source domain
454+
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
455+
samples in the target domain, compute sinkhorn with multiple targets
456+
and fixed M if b is a matrix (return OT loss + dual variables in log)
457+
M : np.ndarray (ns,nt)
458+
loss matrix
459+
reg : float
460+
Regularization term >0
461+
numItermax : int, optional
462+
Max number of iterations
463+
stopThr : float, optional
464+
Stop threshold on error (>0)
465+
log : bool, optional
466+
record log if True
467+
468+
469+
Returns
470+
-------
471+
gamma : (ns x nt) ndarray
472+
Optimal transportation matrix for the given parameters
473+
log : dict
474+
log dictionary return only if log==True in parameters
475+
476+
Examples
477+
--------
478+
479+
>>> import ot
480+
>>> a=[.5,.5]
481+
>>> b=[.5,.5]
482+
>>> M=[[0.,1.],[1.,0.]]
483+
>>> ot.bregman.greenkhorn(a,b,M,1)
484+
array([[ 0.36552929, 0.13447071],
485+
[ 0.13447071, 0.36552929]])
486+
487+
488+
References
489+
----------
490+
491+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
492+
.. [22] J. Altschuler, J. Weed, P. Rigollet: Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 30, 2017
493+
494+
495+
See Also
496+
--------
497+
ot.lp.emd : Unregularized OT
498+
ot.optim.cg : General regularized OT
499+
500+
"""
501+
502+
n = a.shape[0]
503+
m = b.shape[0]
504+
505+
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
506+
K = np.empty_like(M)
507+
np.divide(M, -reg, out=K)
508+
np.exp(K, out=K)
509+
510+
u = np.full(n, 1. / n)
511+
v = np.full(m, 1. / m)
512+
G = u[:, np.newaxis] * K * v[np.newaxis, :]
513+
514+
viol = G.sum(1) - a
515+
viol_2 = G.sum(0) - b
516+
stopThr_val = 1
517+
if log:
518+
log = {'u': u}  # log arrives as a bool, so replace it with a dict before storing entries
519+
log['v'] = v
520+
521+
for i in range(numItermax):
522+
i_1 = np.argmax(np.abs(viol))
523+
i_2 = np.argmax(np.abs(viol_2))
524+
m_viol_1 = np.abs(viol[i_1])
525+
m_viol_2 = np.abs(viol_2[i_2])
526+
stopThr_val = np.maximum(m_viol_1, m_viol_2)
527+
528+
if m_viol_1 > m_viol_2:
529+
old_u = u[i_1]
530+
u[i_1] = a[i_1] / (K[i_1, :].dot(v))
531+
G[i_1, :] = u[i_1] * K[i_1, :] * v
532+
533+
viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
534+
viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
535+
536+
else:
537+
old_v = v[i_2]
538+
v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
539+
G[:, i_2] = u * K[:, i_2] * v[i_2]
540+
#aviol = (G@one_m - a)
541+
#aviol_2 = (G.T@one_n - b)
542+
viol += (-old_v + v[i_2]) * K[:, i_2] * u
543+
viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
544+
545+
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
546+
547+
if stopThr_val <= stopThr:
548+
break
549+
else:
550+
print('Warning: Algorithm did not converge')
551+
552+
if log:
553+
log['u'] = u
554+
log['v'] = v
555+
556+
if log:
557+
return G, log
558+
else:
559+
return G
560+
561+
413562
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
414563
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
415564
"""

test/test_bregman.py

Lines changed: 3 additions & 0 deletions
@@ -71,11 +71,14 @@ def test_sinkhorn_variants():
7171
Ges = ot.sinkhorn(
7272
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
7373
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
74+
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
7475

7576
# check values
7677
np.testing.assert_allclose(G0, Gs, atol=1e-05)
7778
np.testing.assert_allclose(G0, Ges, atol=1e-05)
7879
np.testing.assert_allclose(G0, Gerr)
80+
np.testing.assert_allclose(G0, G_green, atol=1e-5)
81+
print(G0, G_green)
7982

8083

8184
def test_bary():
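
The test hunk above expects `method='do_not_exists'` to return the same plan as the default solver, which suggests that `ot.sinkhorn` falls back to the classic solver for unknown method names. A minimal sketch of that dispatch-with-fallback pattern follows. The stand-in solvers `classic_solver` and `greedy_solver` and the wrapper `solve` are hypothetical and are not POT functions.

```python
def classic_solver(problem):
    # Hypothetical stand-in for the default solver
    return ("classic", problem)


def greedy_solver(problem):
    # Hypothetical stand-in for the greedy variant
    return ("greedy", problem)


SOLVERS = {"sinkhorn": classic_solver, "greenkhorn": greedy_solver}


def solve(problem, method="sinkhorn"):
    # Case-insensitive lookup, mirroring the method.lower() comparisons in the diff
    solver = SOLVERS.get(method.lower())
    if solver is None:
        # Assumed fallback behavior, inferred from the test rather than from the source
        print("Warning: unknown method %r, using the default solver" % method)
        solver = SOLVERS["sinkhorn"]
    return solver(problem)
```

With this shape, `solve(p, method='do_not_exists')` and `solve(p)` return the same result, which is the property the assertion on `Gerr` checks.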
