@@ -58,13 +58,13 @@ def tensor_square_loss(C1, C2, T):
58
58
Metric cost matrix in the source space
59
59
C2 : ndarray, shape (nt, nt)
60
60
Metric costfr matrix in the target space
61
- T : np. ndarray(ns,nt)
61
+ T : ndarray, shape (ns, nt)
62
62
Coupling between source and target spaces
63
63
64
64
65
65
Returns
66
66
-------
67
- tens : (ns* nt) ndarray
67
+ tens : ndarray, shape (ns, nt)
68
68
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
69
69
70
70
@@ -89,7 +89,7 @@ def h2(b):
89
89
tens = - np .dot (h1 (C1 ), T ).dot (h2 (C2 ).T )
90
90
tens -= tens .min ()
91
91
92
- return np . array ( tens )
92
+ return tens
93
93
94
94
95
95
def tensor_kl_loss (C1 , C2 , T ):
@@ -116,13 +116,13 @@ def tensor_kl_loss(C1, C2, T):
116
116
Metric cost matrix in the source space
117
117
C2 : ndarray, shape (nt, nt)
118
118
Metric costfr matrix in the target space
119
- T : np. ndarray(ns,nt)
119
+ T : ndarray, shape (ns, nt)
120
120
Coupling between source and target spaces
121
121
122
122
123
123
Returns
124
124
-------
125
- tens : (ns* nt) ndarray
125
+ tens : ndarray, shape (ns, nt)
126
126
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
127
127
128
128
References
@@ -151,34 +151,36 @@ def h2(b):
151
151
tens = - np .dot (h1 (C1 ), T ).dot (h2 (C2 ).T )
152
152
tens -= tens .min ()
153
153
154
- return np . array ( tens )
154
+ return tens
155
155
156
156
157
157
def update_square_loss (p , lambdas , T , Cs ):
158
158
"""
159
- Updates C according to the L2 Loss kernel with the S Ts couplings calculated at each iteration
159
+ Updates C according to the L2 Loss kernel with the S Ts couplings
160
+ calculated at each iteration
160
161
161
162
162
163
Parameters
163
164
----------
164
- p : np. ndarray(N,)
165
+ p : ndarray, shape (N,)
165
166
weights in the targeted barycenter
166
167
lambdas : list of the S spaces' weights
167
168
T : list of S np.ndarray(ns,N)
168
169
the S Ts couplings calculated at each iteration
169
- Cs : Cs : list of S np. ndarray(ns,ns)
170
+ Cs : list of S ndarray, shape (ns,ns)
170
171
Metric cost matrices
171
172
172
173
Returns
173
174
----------
174
- C updated
175
+ C : ndarray, shape (nt,nt)
176
+ updated C matrix
175
177
176
178
177
179
"""
178
180
tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
179
181
ppt = np .outer (p , p )
180
182
181
- return ( np .divide (tmpsum , ppt ) )
183
+ return np .divide (tmpsum , ppt )
182
184
183
185
184
186
def update_kl_loss (p , lambdas , T , Cs ):
@@ -188,27 +190,28 @@ def update_kl_loss(p, lambdas, T, Cs):
188
190
189
191
Parameters
190
192
----------
191
- p : np. ndarray(N,)
193
+ p : ndarray, shape (N,)
192
194
weights in the targeted barycenter
193
195
lambdas : list of the S spaces' weights
194
196
T : list of S np.ndarray(ns,N)
195
197
the S Ts couplings calculated at each iteration
196
- Cs : Cs : list of S np. ndarray(ns,ns)
198
+ Cs : list of S ndarray, shape (ns,ns)
197
199
Metric cost matrices
198
200
199
201
Returns
200
202
----------
201
- C updated
203
+ C : ndarray, shape (ns,ns)
204
+ updated C matrix
202
205
203
206
204
207
"""
205
208
tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
206
209
ppt = np .outer (p , p )
207
210
208
- return ( np .exp (np .divide (tmpsum , ppt ) ))
211
+ return np .exp (np .divide (tmpsum , ppt ))
209
212
210
213
211
- def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
214
+ def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
212
215
"""
213
216
Returns the gromov-wasserstein coupling between the two measured similarity matrices
214
217
@@ -241,31 +244,28 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
241
244
Metric cost matrix in the source space
242
245
C2 : ndarray, shape (nt, nt)
243
246
Metric costfr matrix in the target space
244
- p : np. ndarray(ns,)
247
+ p : ndarray, shape (ns,)
245
248
distribution in the source space
246
- q : np. ndarray(nt)
249
+ q : ndarray, shape (nt, )
247
250
distribution in the target space
248
- loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
251
+ loss_fun : string
252
+ loss function used for the solver either 'square_loss' or 'kl_loss'
249
253
epsilon : float
250
254
Regularization term >0
251
- <<<<<<< HEAD
252
255
max_iter : int, optional
253
- =======
254
- numItermax : int, optional
255
- >>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
256
- Max number of iterations
257
- stopThr : float, optional
256
+ Max number of iterations
257
+ tol : float, optional
258
258
Stop threshold on error (>0)
259
259
verbose : bool, optional
260
260
Print information along iterations
261
261
log : bool, optional
262
262
record log if True
263
- forcing : np.ndarray(N,2)
264
- list of forced couplings (where N is the number of forcing)
263
+
265
264
266
265
Returns
267
266
-------
268
- T : coupling between the two spaces that minimizes :
267
+ T : ndarray, shape (ns, nt)
268
+ coupling between the two spaces that minimizes :
269
269
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
270
270
271
271
"""
@@ -278,7 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
278
278
cpt = 0
279
279
err = 1
280
280
281
- while (err > stopThr and cpt < max_iter ):
281
+ while (err > tol and cpt < max_iter ):
282
282
283
283
Tprev = T
284
284
@@ -303,15 +303,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
303
303
'It.' , 'Err' ) + '\n ' + '-' * 19 )
304
304
print ('{:5d}|{:8e}|' .format (cpt , err ))
305
305
306
- cpt = cpt + 1
306
+ cpt += 1
307
307
308
308
if log :
309
309
return T , log
310
310
else :
311
311
return T
312
312
313
313
314
- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
314
+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
315
315
"""
316
316
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
317
317
@@ -339,37 +339,36 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
339
339
Metric cost matrix in the source space
340
340
C2 : ndarray, shape (nt, nt)
341
341
Metric costfr matrix in the target space
342
- p : np. ndarray(ns,)
342
+ p : ndarray, shape (ns,)
343
343
distribution in the source space
344
- q : np. ndarray(nt)
344
+ q : ndarray, shape (nt, )
345
345
distribution in the target space
346
- loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
346
+ loss_fun : string
347
+ loss function used for the solver either 'square_loss' or 'kl_loss'
347
348
epsilon : float
348
349
Regularization term >0
349
350
max_iter : int, optional
350
351
Max number of iterations
351
- stopThr : float, optional
352
+ tol : float, optional
352
353
Stop threshold on error (>0)
353
354
verbose : bool, optional
354
355
Print information along iterations
355
356
log : bool, optional
356
357
record log if True
357
- forcing : np.ndarray(N,2)
358
- list of forced couplings (where N is the number of forcing)
359
358
360
359
Returns
361
360
-------
362
- T : coupling between the two spaces that minimizes :
363
- \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
361
+ gw_dist : float
362
+ Gromov-Wasserstein distance
364
363
365
364
"""
366
365
367
366
if log :
368
367
gw , logv = gromov_wasserstein (
369
- C1 , C2 , p , q , loss_fun , epsilon , max_iter , stopThr , verbose , log )
368
+ C1 , C2 , p , q , loss_fun , epsilon , max_iter , tol , verbose , log )
370
369
else :
371
370
gw = gromov_wasserstein (C1 , C2 , p , q , loss_fun ,
372
- epsilon , max_iter , stopThr , verbose , log )
371
+ epsilon , max_iter , tol , verbose , log )
373
372
374
373
if loss_fun == 'square_loss' :
375
374
gw_dist = np .sum (gw * tensor_square_loss (C1 , C2 , gw ))
@@ -383,7 +382,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
383
382
return gw_dist
384
383
385
384
386
- def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
385
+ def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
387
386
"""
388
387
Returns the gromov-wasserstein barycenters of S measured similarity matrices
389
388
@@ -408,7 +407,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
408
407
Metric cost matrices
409
408
ps : list of S np.ndarray(ns,)
410
409
sample weights in the S spaces
411
- p : np. ndarray(N,)
410
+ p : ndarray, shape (N,)
412
411
weights in the targeted barycenter
413
412
lambdas : list of the S spaces' weights
414
413
L : tensor-matrix multiplication function based on specific loss function
@@ -418,7 +417,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
418
417
Regularization term >0
419
418
max_iter : int, optional
420
419
Max number of iterations
421
- stopThr : float, optional
420
+ tol : float, optional
422
421
Stop threshol on error (>0)
423
422
verbose : bool, optional
424
423
Print information along iterations
@@ -427,7 +426,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
427
426
428
427
Returns
429
428
-------
430
- C : Similarity matrix in the barycenter space (permutated arbitrarily)
429
+ C : ndarray, shape (N, N)
430
+ Similarity matrix in the barycenter space (permutated arbitrarily)
431
431
432
432
"""
433
433
@@ -446,7 +446,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
446
446
447
447
error = []
448
448
449
- while (err > stopThr and cpt < max_iter ):
449
+ while (err > tol and cpt < max_iter ):
450
450
Cprev = C
451
451
452
452
T = [gromov_wasserstein (Cs [s ], C , ps [s ], p , loss_fun , epsilon ,
0 commit comments