Skip to content

Commit 84c2723

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
Corrections on Gromov
1 parent 24784ed commit 84c2723

File tree

3 files changed

+42
-24
lines changed

3 files changed

+42
-24
lines changed

examples/plot_gromov.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
"""
2323
Sample two Gaussian distributions (2D and 3D)
2424
=============================================
25-
The Gromov-Wasserstein distance allows to compute distances with samples that
26-
do not belong to the same metric space. For demonstration purpose, we sample
25+
The Gromov-Wasserstein distance allows to compute distances with samples that
26+
do not belong to the same metric space. For demonstration purpose, we sample
2727
two Gaussian distributions in 2- and 3-dimensional spaces.
2828
"""
2929

examples/plot_gromov_barycenter.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
=====================================
44
Gromov-Wasserstein Barycenter example
55
=====================================
6-
This example is designed to show how to use the Gromov-Wassertsein distance
6+
This example is designed to show how to use the Gromov-Wasserstein distance
77
computation in POT.
88
"""
99

@@ -34,8 +34,9 @@
3434

3535
def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
3636
"""
37-
Returns an interpolated point cloud following the dissimilarity matrix C using SMACOF
38-
multidimensional scaling (MDS) in specific dimensionned target space
37+
Returns an interpolated point cloud following the dissimilarity matrix C
38+
using SMACOF multidimensional scaling (MDS) in specific dimensionned
39+
target space
3940
4041
Parameters
4142
----------
@@ -51,7 +52,8 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
5152
Returns
5253
-------
5354
npos : ndarray, shape (R, dim)
54-
Embedded coordinates of the interpolated point cloud (defined with one isometry)
55+
Embedded coordinates of the interpolated point cloud (defined with
56+
one isometry)
5557
"""
5658

5759
rng = np.random.RandomState(seed=3)
@@ -88,10 +90,10 @@ def im2mat(I):
8890
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
8991

9092

91-
square = spi.imread('../data/square.png').astype(np.float64)[:,:,2] / 256
92-
cross = spi.imread('../data/cross.png').astype(np.float64)[:,:,2] / 256
93-
triangle = spi.imread('../data/triangle.png').astype(np.float64)[:,:,2] / 256
94-
star = spi.imread('../data/star.png').astype(np.float64)[:,:,2] / 256
93+
square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
94+
cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
95+
triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
96+
star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
9597

9698
shapes = [square, cross, triangle, star]
9799

ot/gromov.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def tensor_kl_loss(C1, C2, T):
122122
123123
References
124124
----------
125-
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016.
125+
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
126+
"Gromov-Wasserstein averaging of kernel and distance matrices."
127+
International Conference on Machine Learning (ICML). 2016.
126128
127129
"""
128130

@@ -157,7 +159,8 @@ def update_square_loss(p, lambdas, T, Cs):
157159
----------
158160
p : ndarray, shape (N,)
159161
weights in the targeted barycenter
160-
lambdas : list of the S spaces' weights
162+
lambdas : list of float
163+
list of the S spaces' weights
161164
T : list of S np.ndarray(ns,N)
162165
the S Ts couplings calculated at each iteration
163166
Cs : list of S ndarray, shape(ns,ns)
@@ -168,7 +171,8 @@ def update_square_loss(p, lambdas, T, Cs):
168171
C : ndarray, shape (nt,nt)
169172
updated C matrix
170173
"""
171-
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
174+
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
175+
for s in range(len(T))])
172176
ppt = np.outer(p, p)
173177

174178
return np.divide(tmpsum, ppt)
@@ -194,13 +198,15 @@ def update_kl_loss(p, lambdas, T, Cs):
194198
C : ndarray, shape (ns,ns)
195199
updated C matrix
196200
"""
197-
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
201+
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
202+
for s in range(len(T))])
198203
ppt = np.outer(p, p)
199204

200205
return np.exp(np.divide(tmpsum, ppt))
201206

202207

203-
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
208+
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
209+
max_iter=1000, tol=1e-9, verbose=False, log=False):
204210
"""
205211
Returns the gromov-wasserstein coupling between the two measured similarity matrices
206212
@@ -276,7 +282,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
276282
T = sinkhorn(p, q, tens, epsilon)
277283

278284
if cpt % 10 == 0:
279-
# we can speed up the process by checking for the error only all the 10th iterations
285+
# we can speed up the process by checking for the error only all
286+
# the 10th iterations
280287
err = np.linalg.norm(T - Tprev)
281288

282289
if log:
@@ -296,7 +303,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
296303
return T
297304

298305

299-
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
306+
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
307+
max_iter=1000, tol=1e-9, verbose=False, log=False):
300308
"""
301309
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
302310
@@ -363,7 +371,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
363371
return gw_dist
364372

365373

366-
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
374+
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
375+
max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
367376
"""
368377
Returns the gromov-wasserstein barycenters of S measured similarity matrices
369378
@@ -390,7 +399,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
390399
sample weights in the S spaces
391400
p : ndarray, shape(N,)
392401
weights in the targeted barycenter
393-
lambdas : list of the S spaces' weights
402+
lambdas : list of float
403+
list of the S spaces' weights
394404
L : tensor-matrix multiplication function based on specific loss function
395405
update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
396406
with the S Ts couplings calculated at each iteration
@@ -404,6 +414,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
404414
Print information along iterations
405415
log : bool, optional
406416
record log if True
417+
init_C : bool, ndarray, shape(N,N)
418+
random initial value for the C matrix provided by user
407419
408420
Returns
409421
-------
@@ -416,10 +428,13 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
416428
Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
417429
lambdas = np.asarray(lambdas, dtype=np.float64)
418430

419-
# Initialization of C : random SPD matrix
420-
xalea = np.random.randn(N, 2)
421-
C = dist(xalea, xalea)
422-
C /= C.max()
431+
# Initialization of C : random SPD matrix (if not provided by user)
432+
if init_C is None:
433+
xalea = np.random.randn(N, 2)
434+
C = dist(xalea, xalea)
435+
C /= C.max()
436+
else:
437+
C = init_C
423438

424439
cpt = 0
425440
err = 1
@@ -438,7 +453,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
438453
C = update_kl_loss(p, lambdas, T, Cs)
439454

440455
if cpt % 10 == 0:
441-
# we can speed up the process by checking for the error only all the 10th iterations
456+
# we can speed up the process by checking for the error only all
457+
# the 10th iterations
442458
err = np.linalg.norm(C - Cprev)
443459
error.append(err)
444460

0 commit comments

Comments
 (0)