@@ -40,7 +40,6 @@ def tensor_square_loss(C1, C2, T):
40
40
function as the loss function of Gromow-Wasserstein discrepancy.
41
41
42
42
Where :
43
-
44
43
C1 : Metric cost matrix in the source space
45
44
C2 : Metric cost matrix in the target space
46
45
T : A coupling between those two spaces
@@ -61,13 +60,10 @@ def tensor_square_loss(C1, C2, T):
61
60
T : ndarray, shape (ns, nt)
62
61
Coupling between source and target spaces
63
62
64
-
65
63
Returns
66
64
-------
67
65
tens : ndarray, shape (ns, nt)
68
66
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
69
-
70
-
71
67
"""
72
68
73
69
C1 = np .asarray (C1 , dtype = np .float64 )
@@ -119,15 +115,13 @@ def tensor_kl_loss(C1, C2, T):
119
115
T : ndarray, shape (ns, nt)
120
116
Coupling between source and target spaces
121
117
122
-
123
118
Returns
124
119
-------
125
120
tens : ndarray, shape (ns, nt)
126
121
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
127
122
128
123
References
129
124
----------
130
-
131
125
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016.
132
126
133
127
"""
@@ -159,7 +153,6 @@ def update_square_loss(p, lambdas, T, Cs):
159
153
Updates C according to the L2 Loss kernel with the S Ts couplings
160
154
calculated at each iteration
161
155
162
-
163
156
Parameters
164
157
----------
165
158
p : ndarray, shape (N,)
@@ -174,8 +167,6 @@ def update_square_loss(p, lambdas, T, Cs):
174
167
----------
175
168
C : ndarray, shape (nt,nt)
176
169
updated C matrix
177
-
178
-
179
170
"""
180
171
tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
181
172
ppt = np .outer (p , p )
@@ -202,8 +193,6 @@ def update_kl_loss(p, lambdas, T, Cs):
202
193
----------
203
194
C : ndarray, shape (ns,ns)
204
195
updated C matrix
205
-
206
-
207
196
"""
208
197
tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
209
198
ppt = np .outer (p , p )
@@ -229,15 +218,13 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
229
218
\GW\geq 0
230
219
231
220
Where :
232
-
233
221
C1 : Metric cost matrix in the source space
234
222
C2 : Metric cost matrix in the target space
235
223
p : distribution in the source space
236
224
q : distribution in the target space
237
225
L : loss function to account for the misfit between the similarity matrices
238
226
H : entropy
239
227
240
-
241
228
Parameters
242
229
----------
243
230
C1 : ndarray, shape (ns, ns)
@@ -261,13 +248,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
261
248
log : bool, optional
262
249
record log if True
263
250
264
-
265
251
Returns
266
252
-------
267
253
T : ndarray, shape (ns, nt)
268
254
coupling between the two spaces that minimizes :
269
255
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
270
-
271
256
"""
272
257
273
258
C1 = np .asarray (C1 , dtype = np .float64 )
@@ -322,17 +307,14 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
322
307
.. math::
323
308
\GW_Dist = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
324
309
325
-
326
310
Where :
327
-
328
311
C1 : Metric cost matrix in the source space
329
312
C2 : Metric cost matrix in the target space
330
313
p : distribution in the source space
331
314
q : distribution in the target space
332
315
L : loss function to account for the misfit between the similarity matrices
333
316
H : entropy
334
317
335
-
336
318
Parameters
337
319
----------
338
320
C1 : ndarray, shape (ns, ns)
@@ -360,7 +342,6 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
360
342
-------
361
343
gw_dist : float
362
344
Gromov-Wasserstein distance
363
-
364
345
"""
365
346
366
347
if log :
@@ -428,7 +409,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
428
409
-------
429
410
C : ndarray, shape (N, N)
430
411
Similarity matrix in the barycenter space (permutated arbitrarily)
431
-
432
412
"""
433
413
434
414
S = len (Cs )
0 commit comments