Skip to content

Commit b3dc68f

Browse files
cshjinrflamary
andauthored
[MRG] Fix issue 317 (#318)
* Fix issue 317 * Update with docs and tests Co-authored-by: Rémi Flamary <[email protected]>
1 parent ca69658 commit b3dc68f

File tree

2 files changed

+70
-10
lines changed

2 files changed

+70
-10
lines changed

ot/gromov.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
13681368
-------
13691369
C : array-like, shape (`N`, `N`)
13701370
Similarity matrix in the barycenter space (permutated arbitrarily)
1371+
log : dict
1372+
Log dictionary of error during iterations. Return only if `log=True` in parameters.
13711373
13721374
References
13731375
----------
@@ -1401,7 +1403,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
14011403
Cprev = C
14021404

14031405
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
1404-
max_iter, 1e-4, verbose, log) for s in range(S)]
1406+
max_iter, 1e-4, verbose, log=False) for s in range(S)]
14051407
if loss_fun == 'square_loss':
14061408
C = update_square_loss(p, lambdas, T, Cs)
14071409

@@ -1414,9 +1416,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
14141416
err = nx.norm(C - Cprev)
14151417
error.append(err)
14161418

1417-
if log:
1418-
log['err'].append(err)
1419-
14201419
if verbose:
14211420
if cpt % 200 == 0:
14221421
print('{:5s}|{:12s}'.format(
@@ -1425,7 +1424,10 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
14251424

14261425
cpt += 1
14271426

1428-
return C
1427+
if log:
1428+
return C, {"err": error}
1429+
else:
1430+
return C
14291431

14301432

14311433
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
@@ -1479,6 +1481,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
14791481
-------
14801482
C : array-like, shape (`N`, `N`)
14811483
Similarity matrix in the barycenter space (permutated arbitrarily)
1484+
log : dict
1485+
Log dictionary of error during iterations. Return only if `log=True` in parameters.
14821486
14831487
References
14841488
----------
@@ -1513,7 +1517,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
15131517
Cprev = C
15141518

15151519
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
1516-
numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
1520+
numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
15171521
if loss_fun == 'square_loss':
15181522
C = update_square_loss(p, lambdas, T, Cs)
15191523

@@ -1526,9 +1530,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
15261530
err = nx.norm(C - Cprev)
15271531
error.append(err)
15281532

1529-
if log:
1530-
log['err'].append(err)
1531-
15321533
if verbose:
15331534
if cpt % 200 == 0:
15341535
print('{:5s}|{:12s}'.format(
@@ -1537,7 +1538,10 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
15371538

15381539
cpt += 1
15391540

1540-
return C
1541+
if log:
1542+
return C, {"err": error}
1543+
else:
1544+
return C
15411545

15421546

15431547
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,

test/test_gromov.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,20 @@ def test_gromov_barycenter(nx):
385385
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
386386
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
387387

388+
# test of gromov_barycenters with `log` on
389+
Cb_, err_ = ot.gromov.gromov_barycenters(
390+
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
391+
'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
392+
)
393+
Cbb_, errb_ = ot.gromov.gromov_barycenters(
394+
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
395+
'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
396+
)
397+
Cbb_ = nx.to_numpy(Cbb_)
398+
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
399+
np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
400+
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
401+
388402
Cb2 = ot.gromov.gromov_barycenters(
389403
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
390404
'kl_loss', max_iter=100, tol=1e-3, random_state=42
@@ -396,6 +410,20 @@ def test_gromov_barycenter(nx):
396410
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
397411
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
398412

413+
# test of gromov_barycenters with `log` on
414+
Cb2_, err2_ = ot.gromov.gromov_barycenters(
415+
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
416+
'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
417+
)
418+
Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
419+
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
420+
'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
421+
)
422+
Cb2b_ = nx.to_numpy(Cb2b_)
423+
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
424+
np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
425+
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
426+
399427

400428
@pytest.mark.filterwarnings("ignore:divide")
401429
def test_gromov_entropic_barycenter(nx):
@@ -429,6 +457,20 @@ def test_gromov_entropic_barycenter(nx):
429457
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
430458
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
431459

460+
# test of entropic_gromov_barycenters with `log` on
461+
Cb_, err_ = ot.gromov.entropic_gromov_barycenters(
462+
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
463+
'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
464+
)
465+
Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters(
466+
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
467+
'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
468+
)
469+
Cbb_ = nx.to_numpy(Cbb_)
470+
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
471+
np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
472+
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
473+
432474
Cb2 = ot.gromov.entropic_gromov_barycenters(
433475
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
434476
'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
@@ -440,6 +482,20 @@ def test_gromov_entropic_barycenter(nx):
440482
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
441483
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
442484

485+
# test of entropic_gromov_barycenters with `log` on
486+
Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters(
487+
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
488+
'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
489+
)
490+
Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters(
491+
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
492+
'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
493+
)
494+
Cb2b_ = nx.to_numpy(Cb2b_)
495+
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
496+
np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
497+
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
498+
443499

444500
def test_fgw(nx):
445501
n_samples = 50 # nb samples

0 commit comments

Comments
 (0)