Skip to content

Commit e70d542

Browse files
authored
Merge pull request #23 from rflamary/gromov
Gromov-Wasserstein distance
2 parents a53ede9 + c7eef9d commit e70d542

File tree

10 files changed

+855
-2
lines changed

10 files changed

+855
-2
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ It provides the following solvers:
1616
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
1717
* Joint OT matrix and mapping estimation [8].
1818
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
19-
19+
* Gromov-Wasserstein distances and barycenters [12]
2020

2121
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
2222

@@ -184,3 +184,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
184184
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816.
185185

186186
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.
187+
188+
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.

data/cross.png

230 Bytes
Loading

data/square.png

168 Bytes
Loading

data/star.png

225 Bytes
Loading

data/triangle.png

254 Bytes
Loading

examples/plot_gromov.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
==========================
4+
Gromov-Wasserstein example
5+
==========================
6+
This example is designed to show how to use the Gromov-Wassertsein distance
7+
computation in POT.
8+
"""
9+
10+
# Author: Erwan Vautier <[email protected]>
11+
# Nicolas Courty <[email protected]>
12+
#
13+
# License: MIT License
14+
15+
import scipy as sp
16+
import numpy as np
17+
import matplotlib.pylab as pl
18+
19+
import ot
20+
21+
22+
"""
23+
Sample two Gaussian distributions (2D and 3D)
24+
=============================================
25+
The Gromov-Wasserstein distance allows to compute distances with samples that
26+
do not belong to the same metric space. For demonstration purpose, we sample
27+
two Gaussian distributions in 2- and 3-dimensional spaces.
28+
"""
29+
30+
n_samples = 30 # nb samples
31+
32+
mu_s = np.array([0, 0])
33+
cov_s = np.array([[1, 0], [0, 1]])
34+
35+
mu_t = np.array([4, 4, 4])
36+
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
37+
38+
39+
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
40+
P = sp.linalg.sqrtm(cov_t)
41+
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
42+
43+
44+
"""
45+
Plotting the distributions
46+
==========================
47+
"""
48+
fig = pl.figure()
49+
ax1 = fig.add_subplot(121)
50+
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
51+
ax2 = fig.add_subplot(122, projection='3d')
52+
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
53+
pl.show()
54+
55+
56+
"""
57+
Compute distance kernels, normalize them and then display
58+
=========================================================
59+
"""
60+
61+
C1 = sp.spatial.distance.cdist(xs, xs)
62+
C2 = sp.spatial.distance.cdist(xt, xt)
63+
64+
C1 /= C1.max()
65+
C2 /= C2.max()
66+
67+
pl.figure()
68+
pl.subplot(121)
69+
pl.imshow(C1)
70+
pl.subplot(122)
71+
pl.imshow(C2)
72+
pl.show()
73+
74+
"""
75+
Compute Gromov-Wasserstein plans and distance
76+
=============================================
77+
"""
78+
79+
p = ot.unif(n_samples)
80+
q = ot.unif(n_samples)
81+
82+
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
83+
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
84+
85+
print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
86+
87+
pl.figure()
88+
pl.imshow(gw, cmap='jet')
89+
pl.colorbar()
90+
pl.show()

examples/plot_gromov_barycenter.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=====================================
4+
Gromov-Wasserstein Barycenter example
5+
=====================================
6+
This example is designed to show how to use the Gromov-Wasserstein distance
7+
computation in POT.
8+
"""
9+
10+
# Author: Erwan Vautier <[email protected]>
11+
# Nicolas Courty <[email protected]>
12+
#
13+
# License: MIT License
14+
15+
16+
import numpy as np
17+
import scipy as sp
18+
19+
import scipy.ndimage as spi
20+
import matplotlib.pylab as pl
21+
from sklearn import manifold
22+
from sklearn.decomposition import PCA
23+
24+
import ot
25+
26+
"""
27+
28+
Smacof MDS
29+
==========
30+
This function allows to find an embedding of points given a dissimilarity matrix
31+
that will be given by the output of the algorithm
32+
"""
33+
34+
35+
def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
36+
"""
37+
Returns an interpolated point cloud following the dissimilarity matrix C
38+
using SMACOF multidimensional scaling (MDS) in specific dimensionned
39+
target space
40+
41+
Parameters
42+
----------
43+
C : ndarray, shape (ns, ns)
44+
dissimilarity matrix
45+
dim : int
46+
dimension of the targeted space
47+
max_iter : int
48+
Maximum number of iterations of the SMACOF algorithm for a single run
49+
eps : float
50+
relative tolerance w.r.t stress to declare converge
51+
52+
Returns
53+
-------
54+
npos : ndarray, shape (R, dim)
55+
Embedded coordinates of the interpolated point cloud (defined with
56+
one isometry)
57+
"""
58+
59+
rng = np.random.RandomState(seed=3)
60+
61+
mds = manifold.MDS(
62+
dim,
63+
max_iter=max_iter,
64+
eps=1e-9,
65+
dissimilarity='precomputed',
66+
n_init=1)
67+
pos = mds.fit(C).embedding_
68+
69+
nmds = manifold.MDS(
70+
2,
71+
max_iter=max_iter,
72+
eps=1e-9,
73+
dissimilarity="precomputed",
74+
random_state=rng,
75+
n_init=1)
76+
npos = nmds.fit_transform(C, init=pos)
77+
78+
return npos
79+
80+
81+
"""
82+
Data preparation
83+
================
84+
The four distributions are constructed from 4 simple images
85+
"""
86+
87+
88+
def im2mat(I):
89+
"""Converts and image to matrix (one pixel per line)"""
90+
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
91+
92+
93+
square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
94+
cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
95+
triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
96+
star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
97+
98+
shapes = [square, cross, triangle, star]
99+
100+
S = 4
101+
xs = [[] for i in range(S)]
102+
103+
104+
for nb in range(4):
105+
for i in range(8):
106+
for j in range(8):
107+
if shapes[nb][i, j] < 0.95:
108+
xs[nb].append([j, 8 - i])
109+
110+
xs = np.array([np.array(xs[0]), np.array(xs[1]),
111+
np.array(xs[2]), np.array(xs[3])])
112+
113+
114+
"""
115+
Barycenter computation
116+
======================
117+
The four distributions are constructed from 4 simple images
118+
"""
119+
ns = [len(xs[s]) for s in range(S)]
120+
n_samples = 30
121+
122+
"""Compute all distances matrices for the four shapes"""
123+
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
124+
Cs = [cs / cs.max() for cs in Cs]
125+
126+
ps = [ot.unif(ns[s]) for s in range(S)]
127+
p = ot.unif(n_samples)
128+
129+
130+
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
131+
132+
Ct01 = [0 for i in range(2)]
133+
for i in range(2):
134+
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
135+
[ps[0], ps[1]
136+
], p, lambdast[i], 'square_loss', 5e-4,
137+
max_iter=100, stopThr=1e-3)
138+
139+
Ct02 = [0 for i in range(2)]
140+
for i in range(2):
141+
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
142+
[ps[0], ps[2]
143+
], p, lambdast[i], 'square_loss', 5e-4,
144+
max_iter=100, stopThr=1e-3)
145+
146+
Ct13 = [0 for i in range(2)]
147+
for i in range(2):
148+
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
149+
[ps[1], ps[3]
150+
], p, lambdast[i], 'square_loss', 5e-4,
151+
max_iter=100, stopThr=1e-3)
152+
153+
Ct23 = [0 for i in range(2)]
154+
for i in range(2):
155+
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
156+
[ps[2], ps[3]
157+
], p, lambdast[i], 'square_loss', 5e-4,
158+
max_iter=100, stopThr=1e-3)
159+
160+
"""
161+
Visualization
162+
=============
163+
"""
164+
165+
"""The PCA helps in getting consistency between the rotations"""
166+
167+
clf = PCA(n_components=2)
168+
npos = [0, 0, 0, 0]
169+
npos = [smacof_mds(Cs[s], 2) for s in range(S)]
170+
171+
npost01 = [0, 0]
172+
npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
173+
npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
174+
175+
npost02 = [0, 0]
176+
npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
177+
npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
178+
179+
npost13 = [0, 0]
180+
npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
181+
npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
182+
183+
npost23 = [0, 0]
184+
npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
185+
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
186+
187+
188+
fig = pl.figure(figsize=(10, 10))
189+
190+
ax1 = pl.subplot2grid((4, 4), (0, 0))
191+
pl.xlim((-1, 1))
192+
pl.ylim((-1, 1))
193+
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
194+
195+
ax2 = pl.subplot2grid((4, 4), (0, 1))
196+
pl.xlim((-1, 1))
197+
pl.ylim((-1, 1))
198+
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
199+
200+
ax3 = pl.subplot2grid((4, 4), (0, 2))
201+
pl.xlim((-1, 1))
202+
pl.ylim((-1, 1))
203+
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
204+
205+
ax4 = pl.subplot2grid((4, 4), (0, 3))
206+
pl.xlim((-1, 1))
207+
pl.ylim((-1, 1))
208+
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
209+
210+
ax5 = pl.subplot2grid((4, 4), (1, 0))
211+
pl.xlim((-1, 1))
212+
pl.ylim((-1, 1))
213+
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
214+
215+
ax6 = pl.subplot2grid((4, 4), (1, 3))
216+
pl.xlim((-1, 1))
217+
pl.ylim((-1, 1))
218+
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
219+
220+
ax7 = pl.subplot2grid((4, 4), (2, 0))
221+
pl.xlim((-1, 1))
222+
pl.ylim((-1, 1))
223+
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
224+
225+
ax8 = pl.subplot2grid((4, 4), (2, 3))
226+
pl.xlim((-1, 1))
227+
pl.ylim((-1, 1))
228+
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
229+
230+
ax9 = pl.subplot2grid((4, 4), (3, 0))
231+
pl.xlim((-1, 1))
232+
pl.ylim((-1, 1))
233+
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
234+
235+
ax10 = pl.subplot2grid((4, 4), (3, 1))
236+
pl.xlim((-1, 1))
237+
pl.ylim((-1, 1))
238+
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
239+
240+
ax11 = pl.subplot2grid((4, 4), (3, 2))
241+
pl.xlim((-1, 1))
242+
pl.ylim((-1, 1))
243+
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
244+
245+
ax12 = pl.subplot2grid((4, 4), (3, 3))
246+
pl.xlim((-1, 1))
247+
pl.ylim((-1, 1))
248+
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')

ot/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
# Author: Remi Flamary <[email protected]>
8+
# Nicolas Courty <[email protected]>
89
#
910
# License: MIT License
1011

@@ -17,11 +18,13 @@
1718
from . import datasets
1819
from . import plot
1920
from . import da
21+
from . import gromov
2022

2123
# OT functions
2224
from .lp import emd, emd2
2325
from .bregman import sinkhorn, sinkhorn2, barycenter
2426
from .da import sinkhorn_lpl1_mm
27+
from .gromov import gromov_wasserstein, gromov_wasserstein2
2528

2629
# utils functions
2730
from .utils import dist, unif, tic, toc, toq
@@ -30,4 +33,5 @@
3033

3134
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
3235
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
33-
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
36+
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
37+
'gromov_wasserstein','gromov_wasserstein2']

0 commit comments

Comments
 (0)