Skip to content

Commit 49c100d

Browse files
committed
test semi supervised mode ok written for all class | need different tolerance for EMDTransport
1 parent 8e4a793 commit 49c100d

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

test/test_da.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ def test_sinkhorn_lpl1_transport_class():
6464
assert_equal(transp_Xs.shape, Xs.shape)
6565

6666
# test unsupervised vs semi-supervised mode
67-
otda_unsup = ot.da.SinkhornTransport()
68-
otda_unsup.fit(Xs=Xs, Xt=Xt)
67+
otda_unsup = ot.da.SinkhornLpl1Transport()
68+
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
6969
n_unsup = np.sum(otda_unsup.cost_)
7070

71-
otda_semi = ot.da.SinkhornTransport()
71+
otda_semi = ot.da.SinkhornLpl1Transport()
7272
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
7373
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
7474
n_semisup = np.sum(otda_semi.cost_)
@@ -136,11 +136,11 @@ def test_sinkhorn_l1l2_transport_class():
136136
assert_equal(transp_Xs.shape, Xs.shape)
137137

138138
# test unsupervised vs semi-supervised mode
139-
otda_unsup = ot.da.SinkhornTransport()
140-
otda_unsup.fit(Xs=Xs, Xt=Xt)
139+
otda_unsup = ot.da.SinkhornL1l2Transport()
140+
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
141141
n_unsup = np.sum(otda_unsup.cost_)
142142

143-
otda_semi = ot.da.SinkhornTransport()
143+
otda_semi = ot.da.SinkhornL1l2Transport()
144144
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
145145
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
146146
n_semisup = np.sum(otda_semi.cost_)
@@ -152,7 +152,9 @@ def test_sinkhorn_l1l2_transport_class():
152152
# and labeled target samples
153153
mass_semi = np.sum(
154154
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
155-
assert mass_semi == 0, "semisupervised mode not working"
155+
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
156+
assert_allclose(mass_semi, np.zeros_like(mass_semi),
157+
rtol=1e-9, atol=1e-9)
156158

157159
# check everything runs well with log=True
158160
otda = ot.da.SinkhornL1l2Transport(log=True)
@@ -289,11 +291,11 @@ def test_emd_transport_class():
289291
assert_equal(transp_Xs.shape, Xs.shape)
290292

291293
# test unsupervised vs semi-supervised mode
292-
otda_unsup = ot.da.SinkhornTransport()
293-
otda_unsup.fit(Xs=Xs, Xt=Xt)
294+
otda_unsup = ot.da.EMDTransport()
295+
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
294296
n_unsup = np.sum(otda_unsup.cost_)
295297

296-
otda_semi = ot.da.SinkhornTransport()
298+
otda_semi = ot.da.EMDTransport()
297299
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
298300
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
299301
n_semisup = np.sum(otda_semi.cost_)
@@ -305,7 +307,11 @@ def test_emd_transport_class():
305307
# and labeled target samples
306308
mass_semi = np.sum(
307309
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
308-
assert mass_semi == 0, "semisupervised mode not working"
310+
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
311+
312+
# we need to use a small tolerance here, otherwise the test breaks
313+
assert_allclose(mass_semi, np.zeros_like(mass_semi),
314+
rtol=1e-2, atol=1e-2)
309315

310316

311317
def test_mapping_transport_class():
@@ -491,3 +497,4 @@ def test_otda():
491497
# test_sinkhorn_l1l2_transport_class()
492498
# test_sinkhorn_lpl1_transport_class()
493499
# test_mapping_transport_class()
500+

0 commit comments

Comments
 (0)