Skip to content

Commit 818c7ac

Browse files
eddarddrflamary
andauthored
[MRG] Free support Sinkhorn barycenters (#387)
* Adding function for computing Sinkhorn Free Support barycenters * Adding exampel on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to REALEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: Rémi Flamary <[email protected]>
1 parent 7c2a952 commit 818c7ac

File tree

6 files changed

+324
-3
lines changed

6 files changed

+324
-3
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ The contributors to this library are:
3939
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
4040
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
4141
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
42+
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
4243

4344
## Acknowledgments
4445

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#### New features
66

77
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
8+
- Added Free Support Sinkhorn Barycenter + example (PR #387)
89

910
#### Closed issues
1011

examples/barycenters/plot_free_support_barycenter.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
2D free support Wasserstein barycenters of distributions
55
========================================================
66
7-
Illustration of 2D Wasserstein barycenters if distributions are weighted
7+
Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
88
sum of diracs.
99
1010
"""
1111

1212
# Authors: Vivien Seguy <[email protected]>
1313
# Rémi Flamary <[email protected]>
14+
# Eduardo Fernandes Montesuma <[email protected]>
1415
#
1516
# License: MIT License
1617

@@ -48,7 +49,7 @@
4849

4950

5051
# %%
51-
# Compute free support barycenter
52+
# Compute free support Wasserstein barycenter
5253
# -------------------------------
5354

5455
k = 200 # number of Diracs of the barycenter
@@ -58,7 +59,28 @@
5859
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
5960

6061
# %%
61-
# Plot the barycenter
62+
# Plot the Wasserstein barycenter
63+
# ---------
64+
65+
pl.figure(2, (8, 3))
66+
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
67+
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
68+
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
69+
pl.title('Data measures and their barycenter')
70+
pl.legend(loc="lower right")
71+
pl.show()
72+
73+
# %%
74+
# Compute free support Sinkhorn barycenter
75+
76+
k = 200 # number of Diracs of the barycenter
77+
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
78+
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
79+
80+
X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)
81+
82+
# %%
83+
# Plot the Wasserstein barycenter
6284
# ---------
6385

6486
pl.figure(2, (8, 3))
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
========================================================
4+
2D free support Sinkhorn barycenters of distributions
5+
========================================================
6+
7+
Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds
8+
9+
"""
10+
11+
# Authors: Eduardo Fernandes Montesuma <[email protected]>
12+
#
13+
# License: MIT License
14+
15+
import numpy as np
16+
import matplotlib.pyplot as plt
17+
import ot
18+
19+
# %%
20+
# General Parameters
21+
# ------------------
22+
reg = 1e-2 # Entropic Regularization
23+
numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
24+
numInnerItermax = 50 # Maximum number of sinkhorn iterations
25+
n_samples = 200
26+
27+
# %%
28+
# Generate Data
29+
# -------------
30+
31+
X1 = np.random.randn(200, 2)
32+
X2 = 2 * np.concatenate([
33+
np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
34+
np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
35+
np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
36+
np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
37+
], axis=0)
38+
X3 = np.random.randn(200, 2)
39+
X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
40+
X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)
41+
42+
a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))
43+
44+
# %%
45+
# Inspect generated distributions
46+
# -------------------------------
47+
48+
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
49+
50+
axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
51+
axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
52+
axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
53+
axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')
54+
55+
axes[0].set_xlim([-3, 3])
56+
axes[0].set_ylim([-3, 3])
57+
axes[0].set_title('Distribution 1')
58+
59+
axes[1].set_xlim([-3, 3])
60+
axes[1].set_ylim([-3, 3])
61+
axes[1].set_title('Distribution 2')
62+
63+
axes[2].set_xlim([-3, 3])
64+
axes[2].set_ylim([-3, 3])
65+
axes[2].set_title('Distribution 3')
66+
67+
axes[3].set_xlim([-3, 3])
68+
axes[3].set_ylim([-3, 3])
69+
axes[3].set_title('Distribution 4')
70+
71+
plt.tight_layout()
72+
plt.show()
73+
74+
# %%
75+
# Interpolating Empirical Distributions
76+
# -------------------------------------
77+
78+
fig = plt.figure(figsize=(10, 10))
79+
80+
weights = np.array([
81+
[3 / 3, 0 / 3],
82+
[2 / 3, 1 / 3],
83+
[1 / 3, 2 / 3],
84+
[0 / 3, 3 / 3],
85+
]).astype(np.float32)
86+
87+
for k in range(4):
88+
XB_init = np.random.randn(n_samples, 2)
89+
XB = ot.bregman.free_support_sinkhorn_barycenter(
90+
measures_locations=[X1, X2],
91+
measures_weights=[a1, a2],
92+
weights=weights[k],
93+
X_init=XB_init,
94+
reg=reg,
95+
numItermax=numItermax,
96+
numInnerItermax=numInnerItermax
97+
)
98+
ax = plt.subplot2grid((4, 4), (0, k))
99+
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
100+
ax.set_xlim([-3, 3])
101+
ax.set_ylim([-3, 3])
102+
103+
for k in range(1, 4, 1):
104+
XB_init = np.random.randn(n_samples, 2)
105+
XB = ot.bregman.free_support_sinkhorn_barycenter(
106+
measures_locations=[X1, X3],
107+
measures_weights=[a1, a2],
108+
weights=weights[k],
109+
X_init=XB_init,
110+
reg=reg,
111+
numItermax=numItermax,
112+
numInnerItermax=numInnerItermax
113+
)
114+
ax = plt.subplot2grid((4, 4), (k, 0))
115+
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
116+
ax.set_xlim([-3, 3])
117+
ax.set_ylim([-3, 3])
118+
119+
for k in range(1, 4, 1):
120+
XB_init = np.random.randn(n_samples, 2)
121+
XB = ot.bregman.free_support_sinkhorn_barycenter(
122+
measures_locations=[X3, X4],
123+
measures_weights=[a1, a2],
124+
weights=weights[k],
125+
X_init=XB_init,
126+
reg=reg,
127+
numItermax=numItermax,
128+
numInnerItermax=numInnerItermax
129+
)
130+
ax = plt.subplot2grid((4, 4), (3, k))
131+
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
132+
ax.set_xlim([-3, 3])
133+
ax.set_ylim([-3, 3])
134+
135+
for k in range(1, 3, 1):
136+
XB_init = np.random.randn(n_samples, 2)
137+
XB = ot.bregman.free_support_sinkhorn_barycenter(
138+
measures_locations=[X2, X4],
139+
measures_weights=[a1, a2],
140+
weights=weights[k],
141+
X_init=XB_init,
142+
reg=reg,
143+
numItermax=numItermax,
144+
numInnerItermax=numInnerItermax
145+
)
146+
ax = plt.subplot2grid((4, 4), (k, 3))
147+
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
148+
ax.set_xlim([-3, 3])
149+
ax.set_ylim([-3, 3])
150+
151+
plt.show()

ot/bregman.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,126 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
15401540
return geometricBar(weights, UKv)
15411541

15421542

1543+
def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
1544+
numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
1545+
**kwargs):
1546+
r"""
1547+
Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:
1548+
1549+
.. math::
1550+
\min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
1551+
1552+
where :
1553+
1554+
- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
1555+
- `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
1556+
- `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
1557+
- :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
1558+
1559+
This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
1560+
There are two differences with the following codes:
1561+
1562+
- we do not optimize over the weights
1563+
- we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
1564+
:ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
1565+
implementation of the fixed-point algorithm of
1566+
:ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
1567+
- at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
1568+
transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
1569+
1570+
Parameters
1571+
----------
1572+
measures_locations : list of N (k_i,d) array-like
1573+
The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
1574+
(:math:`k_i` can be different for each element of the list)
1575+
measures_weights : list of N (k_i,) array-like
1576+
Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
1577+
representing the weights of each discrete input measure
1578+
1579+
X_init : (k,d) array-like
1580+
Initialization of the support locations (on `k` atoms) of the barycenter
1581+
reg : float
1582+
Regularization term >0
1583+
b : (k,) array-like
1584+
Initialization of the weights of the barycenter (non-negatives, sum to 1)
1585+
weights : (N,) array-like
1586+
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
1587+
1588+
numItermax : int, optional
1589+
Max number of iterations
1590+
numInnerItermax : int, optional
1591+
Max number of iterations when calculating the transport plans with Sinkhorn
1592+
stopThr : float, optional
1593+
Stop threshold on error (>0)
1594+
verbose : bool, optional
1595+
Print information along iterations
1596+
log : bool, optional
1597+
record log if True
1598+
1599+
Returns
1600+
-------
1601+
X : (k,d) array-like
1602+
Support locations (on k atoms) of the barycenter
1603+
1604+
See Also
1605+
--------
1606+
ot.bregman.sinkhorn : Entropic regularized OT solver
1607+
ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming
1608+
1609+
.. _references-free-support-barycenter:
1610+
References
1611+
----------
1612+
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
1613+
1614+
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
1615+
1616+
"""
1617+
nx = get_backend(*measures_locations, *measures_weights, X_init)
1618+
1619+
iter_count = 0
1620+
1621+
N = len(measures_locations)
1622+
k = X_init.shape[0]
1623+
d = X_init.shape[1]
1624+
if b is None:
1625+
b = nx.ones((k,), type_as=X_init) / k
1626+
if weights is None:
1627+
weights = nx.ones((N,), type_as=X_init) / N
1628+
1629+
X = X_init
1630+
1631+
log_dict = {}
1632+
displacement_square_norms = []
1633+
1634+
displacement_square_norm = stopThr + 1.
1635+
1636+
while (displacement_square_norm > stopThr and iter_count < numItermax):
1637+
1638+
T_sum = nx.zeros((k, d), type_as=X_init)
1639+
1640+
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
1641+
M_i = dist(X, measure_locations_i)
1642+
T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
1643+
T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
1644+
1645+
displacement_square_norm = nx.sum((T_sum - X) ** 2)
1646+
if log:
1647+
displacement_square_norms.append(displacement_square_norm)
1648+
1649+
X = T_sum
1650+
1651+
if verbose:
1652+
print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
1653+
1654+
iter_count += 1
1655+
1656+
if log:
1657+
log_dict['displacement_square_norms'] = displacement_square_norms
1658+
return X, log_dict
1659+
else:
1660+
return X
1661+
1662+
15431663
def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
15441664
stopThr=1e-4, verbose=False, log=False, warn=True):
15451665
r"""Compute the entropic wasserstein barycenter in log-domain

test/test_bregman.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Author: Remi Flamary <[email protected]>
44
# Kilian Fatras <[email protected]>
55
# Quang Huy Tran <[email protected]>
6+
# Eduardo Fernandes Montesuma <[email protected]>
67
#
78
# License: MIT License
89

@@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn):
490491
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
491492

492493

494+
def test_free_support_sinkhorn_barycenter():
495+
measures_locations = [
496+
np.array([-1.]).reshape((1, 1)), # First dirac support
497+
np.array([1.]).reshape((1, 1)) # Second dirac support
498+
]
499+
500+
measures_weights = [
501+
np.array([1.]), # First dirac sample weights
502+
np.array([1.]) # Second dirac sample weights
503+
]
504+
505+
# Barycenter initialization
506+
X_init = np.array([-12.]).reshape((1, 1))
507+
508+
# Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
509+
bar_locations = np.array([0.]).reshape((1, 1))
510+
511+
# Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
512+
# term to 1, but this should be, in general, fine-tuned to the problem.
513+
X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)
514+
515+
# Verifies if calculated barycenter matches ground-truth
516+
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
517+
518+
493519
@pytest.mark.parametrize("method, verbose, warn",
494520
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
495521
[True, False], [True, False]))

0 commit comments

Comments
 (0)