@@ -89,7 +89,9 @@ def test_sinkhorn_lpl1_transport_class(nx):
89
89
# test its computed
90
90
otda .fit (Xs = Xs , ys = ys , Xt = Xt )
91
91
assert hasattr (otda , "cost_" )
92
+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
92
93
assert hasattr (otda , "coupling_" )
94
+ assert np .all (np .isfinite (nx .to_numpy (otda .coupling_ ))), "coupling is finite"
93
95
94
96
# test dimensions of coupling
95
97
assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
@@ -148,7 +150,7 @@ def test_sinkhorn_lpl1_transport_class(nx):
148
150
n_semisup = nx .sum (otda_semi .cost_ )
149
151
150
152
# check that the cost matrix norms are indeed different
151
- assert n_unsup != n_semisup , "semisupervised mode not working"
153
+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
152
154
153
155
# check that the coupling forbids mass transport between labeled source
154
156
# and labeled target samples
@@ -238,7 +240,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
238
240
n_semisup = nx .sum (otda_semi .cost_ )
239
241
240
242
# check that the cost matrix norms are indeed different
241
- assert n_unsup != n_semisup , "semisupervised mode not working"
243
+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
242
244
243
245
# check that the coupling forbids mass transport between labeled source
244
246
# and labeled target samples
@@ -331,7 +333,7 @@ def test_sinkhorn_transport_class(nx):
331
333
n_semisup = nx .sum (otda_semi .cost_ )
332
334
333
335
# check that the cost matrix norms are indeed different
334
- assert n_unsup != n_semisup , "semisupervised mode not working"
336
+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
335
337
336
338
# check that the coupling forbids mass transport between labeled source
337
339
# and labeled target samples
@@ -371,6 +373,10 @@ def test_unbalanced_sinkhorn_transport_class(nx):
371
373
# test dimensions of coupling
372
374
assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
373
375
assert_equal (otda .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
376
+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
377
+
378
+ # test coupling
379
+ assert np .all (np .isfinite (nx .to_numpy (otda .coupling_ ))), "coupling is finite"
374
380
375
381
# test transform
376
382
transp_Xs = otda .transform (Xs = Xs )
@@ -409,19 +415,22 @@ def test_unbalanced_sinkhorn_transport_class(nx):
409
415
# test unsupervised vs semi-supervised mode
410
416
otda_unsup = ot .da .SinkhornTransport ()
411
417
otda_unsup .fit (Xs = Xs , Xt = Xt )
418
+ assert not np .any (np .isnan (nx .to_numpy (otda_unsup .cost_ ))), "cost is finite"
412
419
n_unsup = nx .sum (otda_unsup .cost_ )
413
420
414
421
otda_semi = ot .da .SinkhornTransport ()
415
422
otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
423
+ assert not np .any (np .isnan (nx .to_numpy (otda_semi .cost_ ))), "cost is finite"
416
424
assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
417
425
n_semisup = nx .sum (otda_semi .cost_ )
418
426
419
427
# check that the cost matrix norms are indeed different
420
- assert n_unsup != n_semisup , "semisupervised mode not working"
428
+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
421
429
422
430
# check everything runs well with log=True
423
431
otda = ot .da .SinkhornTransport (log = True )
424
432
otda .fit (Xs = Xs , ys = ys , Xt = Xt )
433
+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
425
434
assert len (otda .log_ .keys ()) != 0
426
435
427
436
@@ -448,7 +457,9 @@ def test_emd_transport_class(nx):
448
457
449
458
# test dimensions of coupling
450
459
assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
460
+ assert not np .any (np .isnan (nx .to_numpy (otda .cost_ ))), "cost is finite"
451
461
assert_equal (otda .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
462
+ assert np .all (np .isfinite (nx .to_numpy (otda .coupling_ ))), "coupling is finite"
452
463
453
464
# test margin constraints
454
465
mu_s = unif (ns )
@@ -495,15 +506,22 @@ def test_emd_transport_class(nx):
495
506
# test unsupervised vs semi-supervised mode
496
507
otda_unsup = ot .da .EMDTransport ()
497
508
otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
509
+ assert_equal (otda_unsup .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
510
+ assert not np .any (np .isnan (nx .to_numpy (otda_unsup .cost_ ))), "cost is finite"
511
+ assert_equal (otda_unsup .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
512
+ assert np .all (np .isfinite (nx .to_numpy (otda_unsup .coupling_ ))), "coupling is finite"
498
513
n_unsup = nx .sum (otda_unsup .cost_ )
499
514
500
515
otda_semi = ot .da .EMDTransport ()
501
516
otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
502
517
assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
518
+ assert not np .any (np .isnan (nx .to_numpy (otda_semi .cost_ ))), "cost is finite"
519
+ assert_equal (otda_semi .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
520
+ assert np .all (np .isfinite (nx .to_numpy (otda_semi .coupling_ ))), "coupling is finite"
503
521
n_semisup = nx .sum (otda_semi .cost_ )
504
522
505
523
# check that the cost matrix norms are indeed different
506
- assert n_unsup != n_semisup , "semisupervised mode not working"
524
+ assert np . allclose ( n_unsup , n_semisup , atol = 1e-7 ), "semisupervised mode is not working"
507
525
508
526
# check that the coupling forbids mass transport between labeled source
509
527
# and labeled target samples
0 commit comments