@@ -64,11 +64,11 @@ def test_sinkhorn_lpl1_transport_class():
64
64
assert_equal (transp_Xs .shape , Xs .shape )
65
65
66
66
# test unsupervised vs semi-supervised mode
67
- otda_unsup = ot .da .SinkhornTransport ()
68
- otda_unsup .fit (Xs = Xs , Xt = Xt )
67
+ otda_unsup = ot .da .SinkhornLpl1Transport ()
68
+ otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
69
69
n_unsup = np .sum (otda_unsup .cost_ )
70
70
71
- otda_semi = ot .da .SinkhornTransport ()
71
+ otda_semi = ot .da .SinkhornLpl1Transport ()
72
72
otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
73
73
assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
74
74
n_semisup = np .sum (otda_semi .cost_ )
@@ -136,11 +136,11 @@ def test_sinkhorn_l1l2_transport_class():
136
136
assert_equal (transp_Xs .shape , Xs .shape )
137
137
138
138
# test unsupervised vs semi-supervised mode
139
- otda_unsup = ot .da .SinkhornTransport ()
140
- otda_unsup .fit (Xs = Xs , Xt = Xt )
139
+ otda_unsup = ot .da .SinkhornL1l2Transport ()
140
+ otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
141
141
n_unsup = np .sum (otda_unsup .cost_ )
142
142
143
- otda_semi = ot .da .SinkhornTransport ()
143
+ otda_semi = ot .da .SinkhornL1l2Transport ()
144
144
otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
145
145
assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
146
146
n_semisup = np .sum (otda_semi .cost_ )
@@ -152,7 +152,9 @@ def test_sinkhorn_l1l2_transport_class():
152
152
# and labeled target samples
153
153
mass_semi = np .sum (
154
154
otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ])
155
- assert mass_semi == 0 , "semisupervised mode not working"
155
+ mass_semi = otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ]
156
+ assert_allclose (mass_semi , np .zeros_like (mass_semi ),
157
+ rtol = 1e-9 , atol = 1e-9 )
156
158
157
159
# check everything runs well with log=True
158
160
otda = ot .da .SinkhornL1l2Transport (log = True )
@@ -289,11 +291,11 @@ def test_emd_transport_class():
289
291
assert_equal (transp_Xs .shape , Xs .shape )
290
292
291
293
# test unsupervised vs semi-supervised mode
292
- otda_unsup = ot .da .SinkhornTransport ()
293
- otda_unsup .fit (Xs = Xs , Xt = Xt )
294
+ otda_unsup = ot .da .EMDTransport ()
295
+ otda_unsup .fit (Xs = Xs , ys = ys , Xt = Xt )
294
296
n_unsup = np .sum (otda_unsup .cost_ )
295
297
296
- otda_semi = ot .da .SinkhornTransport ()
298
+ otda_semi = ot .da .EMDTransport ()
297
299
otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
298
300
assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
299
301
n_semisup = np .sum (otda_semi .cost_ )
@@ -305,7 +307,11 @@ def test_emd_transport_class():
305
307
# and labeled target samples
306
308
mass_semi = np .sum (
307
309
otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ])
308
- assert mass_semi == 0 , "semisupervised mode not working"
310
+ mass_semi = otda_semi .coupling_ [otda_semi .cost_ == otda_semi .limit_max ]
311
+
312
+ # we need to use a small tolerance here, otherwise the test breaks
313
+ assert_allclose (mass_semi , np .zeros_like (mass_semi ),
314
+ rtol = 1e-2 , atol = 1e-2 )
309
315
310
316
311
317
def test_mapping_transport_class ():
@@ -491,3 +497,4 @@ def test_otda():
491
497
# test_sinkhorn_l1l2_transport_class()
492
498
# test_sinkhorn_lpl1_transport_class()
493
499
# test_mapping_transport_class()
500
+
0 commit comments