@@ -385,6 +385,20 @@ def test_gromov_barycenter(nx):
385
385
np .testing .assert_allclose (Cb , Cbb , atol = 1e-06 )
386
386
np .testing .assert_allclose (Cbb .shape , (n_samples , n_samples ))
387
387
388
+ # test of gromov_barycenters with `log` on
389
+ Cb_ , err_ = ot .gromov .gromov_barycenters (
390
+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
391
+ 'square_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
392
+ )
393
+ Cbb_ , errb_ = ot .gromov .gromov_barycenters (
394
+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
395
+ 'square_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
396
+ )
397
+ Cbb_ = nx .to_numpy (Cbb_ )
398
+ np .testing .assert_allclose (Cb_ , Cbb_ , atol = 1e-06 )
399
+ np .testing .assert_array_almost_equal (err_ ['err' ], errb_ ['err' ])
400
+ np .testing .assert_allclose (Cbb_ .shape , (n_samples , n_samples ))
401
+
388
402
Cb2 = ot .gromov .gromov_barycenters (
389
403
n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
390
404
'kl_loss' , max_iter = 100 , tol = 1e-3 , random_state = 42
@@ -396,6 +410,20 @@ def test_gromov_barycenter(nx):
396
410
np .testing .assert_allclose (Cb2 , Cb2b , atol = 1e-06 )
397
411
np .testing .assert_allclose (Cb2b .shape , (n_samples , n_samples ))
398
412
413
+ # test of gromov_barycenters with `log` on
414
+ Cb2_ , err2_ = ot .gromov .gromov_barycenters (
415
+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
416
+ 'kl_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
417
+ )
418
+ Cb2b_ , err2b_ = ot .gromov .gromov_barycenters (
419
+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
420
+ 'kl_loss' , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
421
+ )
422
+ Cb2b_ = nx .to_numpy (Cb2b_ )
423
+ np .testing .assert_allclose (Cb2_ , Cb2b_ , atol = 1e-06 )
424
+ np .testing .assert_array_almost_equal (err2_ ['err' ], err2_ ['err' ])
425
+ np .testing .assert_allclose (Cb2b_ .shape , (n_samples , n_samples ))
426
+
399
427
400
428
@pytest .mark .filterwarnings ("ignore:divide" )
401
429
def test_gromov_entropic_barycenter (nx ):
@@ -429,6 +457,20 @@ def test_gromov_entropic_barycenter(nx):
429
457
np .testing .assert_allclose (Cb , Cbb , atol = 1e-06 )
430
458
np .testing .assert_allclose (Cbb .shape , (n_samples , n_samples ))
431
459
460
+ # test of entropic_gromov_barycenters with `log` on
461
+ Cb_ , err_ = ot .gromov .entropic_gromov_barycenters (
462
+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
463
+ 'square_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
464
+ )
465
+ Cbb_ , errb_ = ot .gromov .entropic_gromov_barycenters (
466
+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
467
+ 'square_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
468
+ )
469
+ Cbb_ = nx .to_numpy (Cbb_ )
470
+ np .testing .assert_allclose (Cb_ , Cbb_ , atol = 1e-06 )
471
+ np .testing .assert_array_almost_equal (err_ ['err' ], errb_ ['err' ])
472
+ np .testing .assert_allclose (Cbb_ .shape , (n_samples , n_samples ))
473
+
432
474
Cb2 = ot .gromov .entropic_gromov_barycenters (
433
475
n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
434
476
'kl_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , random_state = 42
@@ -440,6 +482,20 @@ def test_gromov_entropic_barycenter(nx):
440
482
np .testing .assert_allclose (Cb2 , Cb2b , atol = 1e-06 )
441
483
np .testing .assert_allclose (Cb2b .shape , (n_samples , n_samples ))
442
484
485
+ # test of entropic_gromov_barycenters with `log` on
486
+ Cb2_ , err2_ = ot .gromov .entropic_gromov_barycenters (
487
+ n_samples , [C1 , C2 ], [p1 , p2 ], p , [.5 , .5 ],
488
+ 'kl_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
489
+ )
490
+ Cb2b_ , err2b_ = ot .gromov .entropic_gromov_barycenters (
491
+ n_samples , [C1b , C2b ], [p1b , p2b ], pb , [.5 , .5 ],
492
+ 'kl_loss' , 1e-3 , max_iter = 100 , tol = 1e-3 , verbose = True , random_state = 42 , log = True
493
+ )
494
+ Cb2b_ = nx .to_numpy (Cb2b_ )
495
+ np .testing .assert_allclose (Cb2_ , Cb2b_ , atol = 1e-06 )
496
+ np .testing .assert_array_almost_equal (err2_ ['err' ], err2_ ['err' ])
497
+ np .testing .assert_allclose (Cb2b_ .shape , (n_samples , n_samples ))
498
+
443
499
444
500
def test_fgw (nx ):
445
501
n_samples = 50 # nb samples
0 commit comments