Skip to content

Commit 53e1115

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
docstrings + naming
1 parent f12322c commit 53e1115

File tree

3 files changed

+68
-60
lines changed

3 files changed

+68
-60
lines changed

examples/plot_gromov_barycenter.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,19 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
4545
dimension of the targeted space
4646
max_iter : int
4747
Maximum number of iterations of the SMACOF algorithm for a single run
48-
49-
eps : relative tolerance w.r.t stress to declare converge
48+
eps : float
49+
relative tolerance w.r.t stress to declare converge
5050
5151
5252
Returns
5353
-------
54-
npos : R**dim ndarray
54+
npos : ndarray, shape (R, dim)
5555
Embedded coordinates of the interpolated point cloud (defined with one isometry)
5656
5757
5858
"""
5959

60-
seed = np.random.RandomState(seed=3)
60+
rng = np.random.RandomState(seed=3)
6161

6262
mds = manifold.MDS(
6363
dim,
@@ -72,7 +72,7 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
7272
max_iter=max_iter,
7373
eps=1e-9,
7474
dissimilarity="precomputed",
75-
random_state=seed,
75+
random_state=rng,
7676
n_init=1)
7777
npos = nmds.fit_transform(C, init=pos)
7878

@@ -132,23 +132,31 @@ def im2mat(I):
132132

133133
Ct01 = [0 for i in range(2)]
134134
for i in range(2):
135-
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [
136-
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
135+
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
136+
[ps[0], ps[1]
137+
], p, lambdast[i], 'square_loss', 5e-4,
138+
max_iter=100, stopThr=1e-3)
137139

138140
Ct02 = [0 for i in range(2)]
139141
for i in range(2):
140-
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [
141-
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
142+
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
143+
[ps[0], ps[2]
144+
], p, lambdast[i], 'square_loss', 5e-4,
145+
max_iter=100, stopThr=1e-3)
142146

143147
Ct13 = [0 for i in range(2)]
144148
for i in range(2):
145-
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [
146-
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
149+
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
150+
[ps[1], ps[3]
151+
], p, lambdast[i], 'square_loss', 5e-4,
152+
max_iter=100, stopThr=1e-3)
147153

148154
Ct23 = [0 for i in range(2)]
149155
for i in range(2):
150-
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [
151-
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
156+
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
157+
[ps[2], ps[3]
158+
], p, lambdast[i], 'square_loss', 5e-4,
159+
max_iter=100, stopThr=1e-3)
152160

153161
"""
154162
Visualization

ot/gromov.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ def tensor_square_loss(C1, C2, T):
5858
Metric cost matrix in the source space
5959
C2 : ndarray, shape (nt, nt)
6060
Metric costfr matrix in the target space
61-
T : np.ndarray(ns,nt)
61+
T : ndarray, shape (ns, nt)
6262
Coupling between source and target spaces
6363
6464
6565
Returns
6666
-------
67-
tens : (ns*nt) ndarray
67+
tens : ndarray, shape (ns, nt)
6868
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
6969
7070
@@ -89,7 +89,7 @@ def h2(b):
8989
tens = -np.dot(h1(C1), T).dot(h2(C2).T)
9090
tens -= tens.min()
9191

92-
return np.array(tens)
92+
return tens
9393

9494

9595
def tensor_kl_loss(C1, C2, T):
@@ -116,13 +116,13 @@ def tensor_kl_loss(C1, C2, T):
116116
Metric cost matrix in the source space
117117
C2 : ndarray, shape (nt, nt)
118118
Metric costfr matrix in the target space
119-
T : np.ndarray(ns,nt)
119+
T : ndarray, shape (ns, nt)
120120
Coupling between source and target spaces
121121
122122
123123
Returns
124124
-------
125-
tens : (ns*nt) ndarray
125+
tens : ndarray, shape (ns, nt)
126126
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
127127
128128
References
@@ -151,34 +151,36 @@ def h2(b):
151151
tens = -np.dot(h1(C1), T).dot(h2(C2).T)
152152
tens -= tens.min()
153153

154-
return np.array(tens)
154+
return tens
155155

156156

157157
def update_square_loss(p, lambdas, T, Cs):
158158
"""
159-
Updates C according to the L2 Loss kernel with the S Ts couplings calculated at each iteration
159+
Updates C according to the L2 Loss kernel with the S Ts couplings
160+
calculated at each iteration
160161
161162
162163
Parameters
163164
----------
164-
p : np.ndarray(N,)
165+
p : ndarray, shape (N,)
165166
weights in the targeted barycenter
166167
lambdas : list of the S spaces' weights
167168
T : list of S np.ndarray(ns,N)
168169
the S Ts couplings calculated at each iteration
169-
Cs : Cs : list of S np.ndarray(ns,ns)
170+
Cs : list of S ndarray, shape(ns,ns)
170171
Metric cost matrices
171172
172173
Returns
173174
----------
174-
C updated
175+
C : ndarray, shape (nt,nt)
176+
updated C matrix
175177
176178
177179
"""
178180
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
179181
ppt = np.outer(p, p)
180182

181-
return(np.divide(tmpsum, ppt))
183+
return np.divide(tmpsum, ppt)
182184

183185

184186
def update_kl_loss(p, lambdas, T, Cs):
@@ -188,27 +190,28 @@ def update_kl_loss(p, lambdas, T, Cs):
188190
189191
Parameters
190192
----------
191-
p : np.ndarray(N,)
193+
p : ndarray, shape (N,)
192194
weights in the targeted barycenter
193195
lambdas : list of the S spaces' weights
194196
T : list of S np.ndarray(ns,N)
195197
the S Ts couplings calculated at each iteration
196-
Cs : Cs : list of S np.ndarray(ns,ns)
198+
Cs : list of S ndarray, shape(ns,ns)
197199
Metric cost matrices
198200
199201
Returns
200202
----------
201-
C updated
203+
C : ndarray, shape (ns,ns)
204+
updated C matrix
202205
203206
204207
"""
205208
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
206209
ppt = np.outer(p, p)
207210

208-
return(np.exp(np.divide(tmpsum, ppt)))
211+
return np.exp(np.divide(tmpsum, ppt))
209212

210213

211-
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
214+
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
212215
"""
213216
Returns the gromov-wasserstein coupling between the two measured similarity matrices
214217
@@ -241,31 +244,28 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
241244
Metric cost matrix in the source space
242245
C2 : ndarray, shape (nt, nt)
243246
Metric costfr matrix in the target space
244-
p : np.ndarray(ns,)
247+
p : ndarray, shape (ns,)
245248
distribution in the source space
246-
q : np.ndarray(nt)
249+
q : ndarray, shape (nt,)
247250
distribution in the target space
248-
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
251+
loss_fun : string
252+
loss function used for the solver either 'square_loss' or 'kl_loss'
249253
epsilon : float
250254
Regularization term >0
251-
<<<<<<< HEAD
252255
max_iter : int, optional
253-
=======
254-
numItermax : int, optional
255-
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
256-
Max number of iterations
257-
stopThr : float, optional
256+
Max number of iterations
257+
tol : float, optional
258258
Stop threshold on error (>0)
259259
verbose : bool, optional
260260
Print information along iterations
261261
log : bool, optional
262262
record log if True
263-
forcing : np.ndarray(N,2)
264-
list of forced couplings (where N is the number of forcing)
263+
265264
266265
Returns
267266
-------
268-
T : coupling between the two spaces that minimizes :
267+
T : ndarray, shape (ns, nt)
268+
coupling between the two spaces that minimizes :
269269
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
270270
271271
"""
@@ -278,7 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
278278
cpt = 0
279279
err = 1
280280

281-
while (err > stopThr and cpt < max_iter):
281+
while (err > tol and cpt < max_iter):
282282

283283
Tprev = T
284284

@@ -303,15 +303,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
303303
'It.', 'Err') + '\n' + '-' * 19)
304304
print('{:5d}|{:8e}|'.format(cpt, err))
305305

306-
cpt = cpt + 1
306+
cpt += 1
307307

308308
if log:
309309
return T, log
310310
else:
311311
return T
312312

313313

314-
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
314+
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
315315
"""
316316
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
317317
@@ -339,37 +339,36 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
339339
Metric cost matrix in the source space
340340
C2 : ndarray, shape (nt, nt)
341341
Metric costfr matrix in the target space
342-
p : np.ndarray(ns,)
342+
p : ndarray, shape (ns,)
343343
distribution in the source space
344-
q : np.ndarray(nt)
344+
q : ndarray, shape (nt,)
345345
distribution in the target space
346-
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
346+
loss_fun : string
347+
loss function used for the solver either 'square_loss' or 'kl_loss'
347348
epsilon : float
348349
Regularization term >0
349350
max_iter : int, optional
350351
Max number of iterations
351-
stopThr : float, optional
352+
tol : float, optional
352353
Stop threshold on error (>0)
353354
verbose : bool, optional
354355
Print information along iterations
355356
log : bool, optional
356357
record log if True
357-
forcing : np.ndarray(N,2)
358-
list of forced couplings (where N is the number of forcing)
359358
360359
Returns
361360
-------
362-
T : coupling between the two spaces that minimizes :
363-
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
361+
gw_dist : float
362+
Gromov-Wasserstein distance
364363
365364
"""
366365

367366
if log:
368367
gw, logv = gromov_wasserstein(
369-
C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
368+
C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log)
370369
else:
371370
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
372-
epsilon, max_iter, stopThr, verbose, log)
371+
epsilon, max_iter, tol, verbose, log)
373372

374373
if loss_fun == 'square_loss':
375374
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -383,7 +382,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
383382
return gw_dist
384383

385384

386-
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
385+
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
387386
"""
388387
Returns the gromov-wasserstein barycenters of S measured similarity matrices
389388
@@ -408,7 +407,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
408407
Metric cost matrices
409408
ps : list of S np.ndarray(ns,)
410409
sample weights in the S spaces
411-
p : np.ndarray(N,)
410+
p : ndarray, shape(N,)
412411
weights in the targeted barycenter
413412
lambdas : list of the S spaces' weights
414413
L : tensor-matrix multiplication function based on specific loss function
@@ -418,7 +417,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
418417
Regularization term >0
419418
max_iter : int, optional
420419
Max number of iterations
421-
stopThr : float, optional
420+
tol : float, optional
422421
Stop threshol on error (>0)
423422
verbose : bool, optional
424423
Print information along iterations
@@ -427,7 +426,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
427426
428427
Returns
429428
-------
430-
C : Similarity matrix in the barycenter space (permutated arbitrarily)
429+
C : ndarray, shape (N, N)
430+
Similarity matrix in the barycenter space (permutated arbitrarily)
431431
432432
"""
433433

@@ -446,7 +446,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
446446

447447
error = []
448448

449-
while(err > stopThr and cpt < max_iter):
449+
while(err > tol and cpt < max_iter):
450450
Cprev = C
451451

452452
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,

test/test_gromov.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_gromov():
1717

1818
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
1919

20-
xt = [xs[n_samples - (i + 1)] for i in range(n_samples)]
20+
xt = xs[::-1]
2121
xt = np.array(xt)
2222

2323
p = ot.unif(n_samples)

0 commit comments

Comments
 (0)