Skip to content

Commit 2097116

Browse files
committed
solving pb
1 parent 49c100d commit 2097116

File tree

1 file changed

+34
-27
lines changed

1 file changed

+34
-27
lines changed

test/test_da.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ def test_sinkhorn_lpl1_transport_class():
3636
# test margin constraints
3737
mu_s = unif(ns)
3838
mu_t = unif(nt)
39-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
40-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
39+
assert_allclose(
40+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
41+
assert_allclose(
42+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
4143

4244
# test transform
4345
transp_Xs = otda.transform(Xs=Xs)
@@ -108,8 +110,10 @@ def test_sinkhorn_l1l2_transport_class():
108110
# test margin constraints
109111
mu_s = unif(ns)
110112
mu_t = unif(nt)
111-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
112-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
113+
assert_allclose(
114+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
115+
assert_allclose(
116+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
113117

114118
# test transform
115119
transp_Xs = otda.transform(Xs=Xs)
@@ -187,8 +191,10 @@ def test_sinkhorn_transport_class():
187191
# test margin constraints
188192
mu_s = unif(ns)
189193
mu_t = unif(nt)
190-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
191-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
194+
assert_allclose(
195+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
196+
assert_allclose(
197+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
192198

193199
# test transform
194200
transp_Xs = otda.transform(Xs=Xs)
@@ -263,8 +269,10 @@ def test_emd_transport_class():
263269
# test margin constraints
264270
mu_s = unif(ns)
265271
mu_t = unif(nt)
266-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
267-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
272+
assert_allclose(
273+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
274+
assert_allclose(
275+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
268276

269277
# test transform
270278
transp_Xs = otda.transform(Xs=Xs)
@@ -342,8 +350,10 @@ def test_mapping_transport_class():
342350
# test margin constraints
343351
mu_s = unif(ns)
344352
mu_t = unif(nt)
345-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
346-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
353+
assert_allclose(
354+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
355+
assert_allclose(
356+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
347357

348358
# test transform
349359
transp_Xs = otda.transform(Xs=Xs)
@@ -363,8 +373,10 @@ def test_mapping_transport_class():
363373
# test margin constraints
364374
mu_s = unif(ns)
365375
mu_t = unif(nt)
366-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
367-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
376+
assert_allclose(
377+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
378+
assert_allclose(
379+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
368380

369381
# test transform
370382
transp_Xs = otda.transform(Xs=Xs)
@@ -389,8 +401,10 @@ def test_mapping_transport_class():
389401
# test margin constraints
390402
mu_s = unif(ns)
391403
mu_t = unif(nt)
392-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
393-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
404+
assert_allclose(
405+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
406+
assert_allclose(
407+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
394408

395409
# test transform
396410
transp_Xs = otda.transform(Xs=Xs)
@@ -410,8 +424,10 @@ def test_mapping_transport_class():
410424
# test margin constraints
411425
mu_s = unif(ns)
412426
mu_t = unif(nt)
413-
assert_allclose(np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
414-
assert_allclose(np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
427+
assert_allclose(
428+
np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
429+
assert_allclose(
430+
np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
415431

416432
# test transform
417433
transp_Xs = otda.transform(Xs=Xs)
@@ -454,7 +470,8 @@ def test_otda():
454470
da_entrop.interp()
455471
da_entrop.predict(xs)
456472

457-
np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
473+
np.testing.assert_allclose(
474+
a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
458475
np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
459476

460477
# non-convex Group lasso regularization
@@ -488,13 +505,3 @@ def test_otda():
488505
da_emd = ot.da.OTDA_mapping_kernel() # init class
489506
da_emd.fit(xs, xt, numItermax=10) # fit distributions
490507
da_emd.predict(xs) # interpolation of source samples
491-
492-
493-
# if __name__ == "__main__":
494-
495-
# test_sinkhorn_transport_class()
496-
# test_emd_transport_class()
497-
# test_sinkhorn_l1l2_transport_class()
498-
# test_sinkhorn_lpl1_transport_class()
499-
# test_mapping_transport_class()
500-

0 commit comments

Comments
 (0)