Skip to content

Commit 63b34bf

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
fixed conflicts
1 parent fd6371c commit 63b34bf

File tree

1 file changed

+0
-63
lines changed

1 file changed

+0
-63
lines changed

test/test_da.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -484,66 +484,3 @@ def test_linear_mapping_class():
484484
Cst = np.cov(Xst.T)
485485

486486
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
487-
488-
489-
def test_otda():
490-
491-
n_samples = 150 # nb samples
492-
np.random.seed(0)
493-
494-
xs, ys = ot.datasets.make_data_classif('3gauss', n_samples)
495-
xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples)
496-
497-
a, b = ot.unif(n_samples), ot.unif(n_samples)
498-
499-
# LP problem
500-
da_emd = ot.da.OTDA() # init class
501-
da_emd.fit(xs, xt) # fit distributions
502-
da_emd.interp() # interpolation of source samples
503-
da_emd.predict(xs) # interpolation of source samples
504-
505-
np.testing.assert_allclose(a, np.sum(da_emd.G, 1))
506-
np.testing.assert_allclose(b, np.sum(da_emd.G, 0))
507-
508-
# sinkhorn regularization
509-
lambd = 1e-1
510-
da_entrop = ot.da.OTDA_sinkhorn()
511-
da_entrop.fit(xs, xt, reg=lambd)
512-
da_entrop.interp()
513-
da_entrop.predict(xs)
514-
515-
np.testing.assert_allclose(
516-
a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
517-
np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
518-
519-
# non-convex Group lasso regularization
520-
reg = 1e-1
521-
eta = 1e0
522-
da_lpl1 = ot.da.OTDA_lpl1()
523-
da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta)
524-
da_lpl1.interp()
525-
da_lpl1.predict(xs)
526-
527-
np.testing.assert_allclose(a, np.sum(da_lpl1.G, 1), rtol=1e-3, atol=1e-3)
528-
np.testing.assert_allclose(b, np.sum(da_lpl1.G, 0), rtol=1e-3, atol=1e-3)
529-
530-
# True Group lasso regularization
531-
reg = 1e-1
532-
eta = 2e0
533-
da_l1l2 = ot.da.OTDA_l1l2()
534-
da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True)
535-
da_l1l2.interp()
536-
da_l1l2.predict(xs)
537-
538-
np.testing.assert_allclose(a, np.sum(da_l1l2.G, 1), rtol=1e-3, atol=1e-3)
539-
np.testing.assert_allclose(b, np.sum(da_l1l2.G, 0), rtol=1e-3, atol=1e-3)
540-
541-
# linear mapping
542-
da_emd = ot.da.OTDA_mapping_linear() # init class
543-
da_emd.fit(xs, xt, numItermax=10) # fit distributions
544-
da_emd.predict(xs) # interpolation of source samples
545-
546-
# nonlinear mapping
547-
da_emd = ot.da.OTDA_mapping_kernel() # init class
548-
da_emd.fit(xs, xt, numItermax=10) # fit distributions
549-
da_emd.predict(xs) # interpolation of source samples

0 commit comments

Comments
 (0)