Skip to content

Commit 24784ed

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
Corrections on Gromov
1 parent 36bf599 commit 24784ed

File tree

1 file changed

+0
-20
lines changed

1 file changed

+0
-20
lines changed

ot/gromov.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def tensor_square_loss(C1, C2, T):
4040
function as the loss function of Gromow-Wasserstein discrepancy.
4141
4242
Where :
43-
4443
C1 : Metric cost matrix in the source space
4544
C2 : Metric cost matrix in the target space
4645
T : A coupling between those two spaces
@@ -61,13 +60,10 @@ def tensor_square_loss(C1, C2, T):
6160
T : ndarray, shape (ns, nt)
6261
Coupling between source and target spaces
6362
64-
6563
Returns
6664
-------
6765
tens : ndarray, shape (ns, nt)
6866
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
69-
70-
7167
"""
7268

7369
C1 = np.asarray(C1, dtype=np.float64)
@@ -119,15 +115,13 @@ def tensor_kl_loss(C1, C2, T):
119115
T : ndarray, shape (ns, nt)
120116
Coupling between source and target spaces
121117
122-
123118
Returns
124119
-------
125120
tens : ndarray, shape (ns, nt)
126121
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
127122
128123
References
129124
----------
130-
131125
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016.
132126
133127
"""
@@ -159,7 +153,6 @@ def update_square_loss(p, lambdas, T, Cs):
159153
Updates C according to the L2 Loss kernel with the S Ts couplings
160154
calculated at each iteration
161155
162-
163156
Parameters
164157
----------
165158
p : ndarray, shape (N,)
@@ -174,8 +167,6 @@ def update_square_loss(p, lambdas, T, Cs):
174167
----------
175168
C : ndarray, shape (nt,nt)
176169
updated C matrix
177-
178-
179170
"""
180171
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
181172
ppt = np.outer(p, p)
@@ -202,8 +193,6 @@ def update_kl_loss(p, lambdas, T, Cs):
202193
----------
203194
C : ndarray, shape (ns,ns)
204195
updated C matrix
205-
206-
207196
"""
208197
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
209198
ppt = np.outer(p, p)
@@ -229,15 +218,13 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
229218
\GW\geq 0
230219
231220
Where :
232-
233221
C1 : Metric cost matrix in the source space
234222
C2 : Metric cost matrix in the target space
235223
p : distribution in the source space
236224
q : distribution in the target space
237225
L : loss function to account for the misfit between the similarity matrices
238226
H : entropy
239227
240-
241228
Parameters
242229
----------
243230
C1 : ndarray, shape (ns, ns)
@@ -261,13 +248,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
261248
log : bool, optional
262249
record log if True
263250
264-
265251
Returns
266252
-------
267253
T : ndarray, shape (ns, nt)
268254
coupling between the two spaces that minimizes :
269255
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
270-
271256
"""
272257

273258
C1 = np.asarray(C1, dtype=np.float64)
@@ -322,17 +307,14 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
322307
.. math::
323308
\GW_Dist = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
324309
325-
326310
Where :
327-
328311
C1 : Metric cost matrix in the source space
329312
C2 : Metric cost matrix in the target space
330313
p : distribution in the source space
331314
q : distribution in the target space
332315
L : loss function to account for the misfit between the similarity matrices
333316
H : entropy
334317
335-
336318
Parameters
337319
----------
338320
C1 : ndarray, shape (ns, ns)
@@ -360,7 +342,6 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
360342
-------
361343
gw_dist : float
362344
Gromov-Wasserstein distance
363-
364345
"""
365346

366347
if log:
@@ -428,7 +409,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
428409
-------
429410
C : ndarray, shape (N, N)
430411
Similarity matrix in the barycenter space (permutated arbitrarily)
431-
432412
"""
433413

434414
S = len(Cs)

0 commit comments

Comments
 (0)