@@ -122,7 +122,9 @@ def tensor_kl_loss(C1, C2, T):
122
122
123
123
References
124
124
----------
125
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016.
125
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
126
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
127
+ International Conference on Machine Learning (ICML). 2016.
126
128
127
129
"""
128
130
@@ -157,7 +159,8 @@ def update_square_loss(p, lambdas, T, Cs):
157
159
----------
158
160
p : ndarray, shape (N,)
159
161
weights in the targeted barycenter
160
- lambdas : list of the S spaces' weights
162
+ lambdas : list of float
163
+ list of the S spaces' weights
161
164
T : list of S np.ndarray(ns,N)
162
165
the S Ts couplings calculated at each iteration
163
166
Cs : list of S ndarray, shape(ns,ns)
@@ -168,7 +171,8 @@ def update_square_loss(p, lambdas, T, Cs):
168
171
C : ndarray, shape (nt,nt)
169
172
updated C matrix
170
173
"""
171
- tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
174
+ tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ])
175
+ for s in range (len (T ))])
172
176
ppt = np .outer (p , p )
173
177
174
178
return np .divide (tmpsum , ppt )
@@ -194,13 +198,15 @@ def update_kl_loss(p, lambdas, T, Cs):
194
198
C : ndarray, shape (ns,ns)
195
199
updated C matrix
196
200
"""
197
- tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
201
+ tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ])
202
+ for s in range (len (T ))])
198
203
ppt = np .outer (p , p )
199
204
200
205
return np .exp (np .divide (tmpsum , ppt ))
201
206
202
207
203
- def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
208
+ def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon ,
209
+ max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
204
210
"""
205
211
Returns the gromov-wasserstein coupling between the two measured similarity matrices
206
212
@@ -276,7 +282,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
276
282
T = sinkhorn (p , q , tens , epsilon )
277
283
278
284
if cpt % 10 == 0 :
279
- # we can speed up the process by checking for the error only all the 10th iterations
285
+ # we can speed up the process by checking for the error only all
286
+ # the 10th iterations
280
287
err = np .linalg .norm (T - Tprev )
281
288
282
289
if log :
@@ -296,7 +303,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
296
303
return T
297
304
298
305
299
- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
306
+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon ,
307
+ max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
300
308
"""
301
309
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
302
310
@@ -363,7 +371,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
363
371
return gw_dist
364
372
365
373
366
- def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
374
+ def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon ,
375
+ max_iter = 1000 , tol = 1e-9 , verbose = False , log = False , init_C = None ):
367
376
"""
368
377
Returns the gromov-wasserstein barycenters of S measured similarity matrices
369
378
@@ -390,7 +399,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
390
399
sample weights in the S spaces
391
400
p : ndarray, shape(N,)
392
401
weights in the targeted barycenter
393
- lambdas : list of the S spaces' weights
402
+ lambdas : list of float
403
+ list of the S spaces' weights
394
404
L : tensor-matrix multiplication function based on specific loss function
395
405
update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
396
406
with the S Ts couplings calculated at each iteration
@@ -404,6 +414,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
404
414
Print information along iterations
405
415
log : bool, optional
406
416
record log if True
417
+ init_C : bool, ndarray, shape(N,N)
418
+ random initial value for the C matrix provided by user
407
419
408
420
Returns
409
421
-------
@@ -416,10 +428,13 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
416
428
Cs = [np .asarray (Cs [s ], dtype = np .float64 ) for s in range (S )]
417
429
lambdas = np .asarray (lambdas , dtype = np .float64 )
418
430
419
- # Initialization of C : random SPD matrix
420
- xalea = np .random .randn (N , 2 )
421
- C = dist (xalea , xalea )
422
- C /= C .max ()
431
+ # Initialization of C : random SPD matrix (if not provided by user)
432
+ if init_C is None :
433
+ xalea = np .random .randn (N , 2 )
434
+ C = dist (xalea , xalea )
435
+ C /= C .max ()
436
+ else :
437
+ C = init_C
423
438
424
439
cpt = 0
425
440
err = 1
@@ -438,7 +453,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
438
453
C = update_kl_loss (p , lambdas , T , Cs )
439
454
440
455
if cpt % 10 == 0 :
441
- # we can speed up the process by checking for the error only all the 10th iterations
456
+ # we can speed up the process by checking for the error only all
457
+ # the 10th iterations
442
458
err = np .linalg .norm (C - Cprev )
443
459
error .append (err )
444
460
0 commit comments