Commit c5108ef

Merge pull request #80 from kilianFatras/master
add empirical sinkhorn and sinkhorn divergence functions
2 parents 2384380 + 17fa4f9 commit c5108ef

File tree

5 files changed: +401 -1 lines changed

README.md

Lines changed: 4 additions & 1 deletion
@@ -15,7 +15,8 @@ This open source Python library provide several solvers for optimization problem
 It provides the following solvers:

 * OT Network Flow solver for the linear program/ Earth Movers Distance [1].
-* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] and greedy SInkhorn [22] with optional GPU implementation (requires cupy).
+* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2], stabilized version [9][10] and greedy Sinkhorn [22] with optional GPU implementation (requires cupy).
+* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
 * Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
 * Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
 * Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4].
@@ -230,3 +231,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.

 [22] J. Altschuler, J. Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
+
+[23] Genevay, A., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
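The solvers this PR adds are exposed under `ot.bregman`; a minimal usage sketch (the sample data and variable names here are illustrative, not from the diff):

```python
import numpy as np
import ot

# two small empirical point clouds (illustrative data)
X_s = np.random.randn(50, 2)         # source samples
X_t = np.random.randn(50, 2) + 1.0   # target samples
reg = 0.1                            # entropic regularization term

# OT plan and OT loss estimated directly from the samples
G = ot.bregman.empirical_sinkhorn(X_s, X_t, reg)
W = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)

# debiased Sinkhorn divergence of [23]: S = W(a, b) - (W(a, a) + W(b, b)) / 2
S = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)
```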

examples/plot_OT_2D_samples.py

Lines changed: 26 additions & 0 deletions
@@ -10,6 +10,7 @@
 """

 # Author: Remi Flamary <[email protected]>
+#         Kilian Fatras <[email protected]>
 #
 # License: MIT License

@@ -100,3 +101,28 @@
 pl.title('OT matrix Sinkhorn with samples')

 pl.show()
+
+
+##############################################################################
+# Empirical Sinkhorn
+# ------------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+pl.figure(7)
+pl.imshow(Ges, interpolation='nearest')
+pl.title('OT matrix empirical sinkhorn')
+
+pl.figure(8)
+ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix empirical Sinkhorn with samples')
+
+pl.show()
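Since `empirical_sinkhorn` only builds the cost matrix and defers to the existing solver (see `ot/bregman.py` below), it should agree with the manual pipeline. A quick sanity check, assuming uniform weights and the default squared Euclidean metric (the clouds here stand in for the example's `xs`, `xt`):

```python
import numpy as np
import ot

# hypothetical small clouds standing in for xs, xt from the example
xs = np.random.randn(20, 2)
xt = np.random.randn(20, 2) + 2.0
lambd = 1e-1

M = ot.dist(xs, xt)  # squared Euclidean cost matrix by default
G_manual = ot.sinkhorn(ot.unif(20), ot.unif(20), M, lambd)
G_emp = ot.bregman.empirical_sinkhorn(xs, xt, lambd)

assert np.allclose(G_manual, G_emp)
```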

ot/bregman.py

Lines changed: 301 additions & 0 deletions
@@ -5,10 +5,12 @@

 # Author: Remi Flamary <[email protected]>
 #         Nicolas Courty <[email protected]>
+#         Kilian Fatras <[email protected]>
 #
 # License: MIT License

 import numpy as np
+from .utils import unif, dist


 def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1296,3 +1298,302 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
         return np.sum(K0, axis=1), log
     else:
         return np.sum(K0, axis=1)
+
+
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Solve the entropic regularization optimal transport problem and return the
+    OT matrix from empirical data
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+
+    where :
+
+    - :math:`M` is the (ns, nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term > 0
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Regularized optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 2
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False)
+    >>> print(emp_sinkhorn)
+    [[4.99977301e-01 2.26989344e-05]
+     [2.26989344e-05 4.99977301e-01]]
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    '''
+
+    # default to uniform weights on the samples
+    if a is None:
+        a = unif(np.shape(X_s)[0])
+    if b is None:
+        b = unif(np.shape(X_t)[0])
+
+    # build the pairwise cost matrix, then defer to the generic solver
+    M = dist(X_s, X_t, metric=metric)
+
+    if log:
+        pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+        return pi, log
+    else:
+        pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+        return pi
+
+
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Solve the entropic regularization optimal transport problem from empirical
+    data and return the OT loss
+
+
+    The function solves the following optimization problem:
+
+    .. math::
+        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+
+    where :
+
+    - :math:`M` is the (ns, nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term > 0
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    W : (1,) ndarray
+        Regularized optimal transportation loss for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 2
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+    >>> print(loss_sinkhorn)
+    [4.53978687e-05]
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    '''
+
+    # default to uniform weights on the samples
+    if a is None:
+        a = unif(np.shape(X_s)[0])
+    if b is None:
+        b = unif(np.shape(X_t)[0])
+
+    # build the pairwise cost matrix, then evaluate the regularized loss
+    M = dist(X_s, X_t, metric=metric)
+
+    if log:
+        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return sinkhorn_loss, log
+    else:
+        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return sinkhorn_loss
+
+
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Compute the Sinkhorn divergence loss from empirical data
+
+    The function solves the following optimization problems and returns the
+    Sinkhorn divergence :math:`S`:
+
+    .. math::
+
+        W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+        W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+
+        W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+
+        S &= W - \frac{1}{2} (W_a + W_b)
+
+    .. math::
+        s.t. \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma \geq 0
+
+             \gamma_a 1 = a
+
+             \gamma_a^T 1 = a
+
+             \gamma_a \geq 0
+
+             \gamma_b 1 = b
+
+             \gamma_b^T 1 = b
+
+             \gamma_b \geq 0
+
+    where :
+
+    - :math:`M` (resp. :math:`M_a`, :math:`M_b`) is the (ns, nt) metric cost matrix (resp. (ns, ns) and (nt, nt))
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term > 0
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    S : (1,) ndarray
+        Sinkhorn divergence for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 4
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
+    >>> print(emp_sinkhorn_div)
+    [2.99977435]
+
+
+    References
+    ----------
+
+    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
+    '''
+    if log:
+        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        # the self-transport terms use the same weights on both marginals
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        # debiased divergence: S = W(a, b) - (W(a, a) + W(b, b)) / 2
+        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+
+        log = {}
+        log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
+        log['sinkhorn_loss_a'] = sinkhorn_loss_a
+        log['sinkhorn_loss_b'] = sinkhorn_loss_b
+        log['log_sinkhorn_ab'] = log_ab
+        log['log_sinkhorn_a'] = log_a
+        log['log_sinkhorn_b'] = log_b
+
+        # clamp at zero: the entropic estimate can dip slightly below zero
+        return max(0, sinkhorn_div), log
+
+    else:
+        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+        return max(0, sinkhorn_div)
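The `max(0, ...)` clamp reflects that the divergence is nonnegative in theory but its entropic estimate can come out slightly negative. A short behavioral sketch (illustrative data, not from the diff): the divergence of a cloud with itself should be near zero, and it should grow as the clouds separate:

```python
import numpy as np
import ot

X = np.random.randn(30, 2)
reg = 0.5

# self-divergence: the -(W_a + W_b)/2 terms cancel the entropic bias
print(ot.bregman.empirical_sinkhorn_divergence(X, X, reg))

# divergence increases with the separation between the clouds
for shift in (0.5, 1.0, 2.0):
    print(shift, ot.bregman.empirical_sinkhorn_divergence(X, X + shift, reg))
```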

ot/stochastic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,

     .. math::
         \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
         s.t. \gamma 1 = a
+
             \gamma^T 1 = b
+
             \gamma \geq 0

     Where :
