@@ -484,66 +484,3 @@ def test_linear_mapping_class():
484
484
Cst = np .cov (Xst .T )
485
485
486
486
np .testing .assert_allclose (Ct , Cst , rtol = 1e-2 , atol = 1e-2 )
487
-
488
-
489
- def test_otda ():
490
-
491
- n_samples = 150 # nb samples
492
- np .random .seed (0 )
493
-
494
- xs , ys = ot .datasets .make_data_classif ('3gauss' , n_samples )
495
- xt , yt = ot .datasets .make_data_classif ('3gauss2' , n_samples )
496
-
497
- a , b = ot .unif (n_samples ), ot .unif (n_samples )
498
-
499
- # LP problem
500
- da_emd = ot .da .OTDA () # init class
501
- da_emd .fit (xs , xt ) # fit distributions
502
- da_emd .interp () # interpolation of source samples
503
- da_emd .predict (xs ) # interpolation of source samples
504
-
505
- np .testing .assert_allclose (a , np .sum (da_emd .G , 1 ))
506
- np .testing .assert_allclose (b , np .sum (da_emd .G , 0 ))
507
-
508
- # sinkhorn regularization
509
- lambd = 1e-1
510
- da_entrop = ot .da .OTDA_sinkhorn ()
511
- da_entrop .fit (xs , xt , reg = lambd )
512
- da_entrop .interp ()
513
- da_entrop .predict (xs )
514
-
515
- np .testing .assert_allclose (
516
- a , np .sum (da_entrop .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
517
- np .testing .assert_allclose (b , np .sum (da_entrop .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
518
-
519
- # non-convex Group lasso regularization
520
- reg = 1e-1
521
- eta = 1e0
522
- da_lpl1 = ot .da .OTDA_lpl1 ()
523
- da_lpl1 .fit (xs , ys , xt , reg = reg , eta = eta )
524
- da_lpl1 .interp ()
525
- da_lpl1 .predict (xs )
526
-
527
- np .testing .assert_allclose (a , np .sum (da_lpl1 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
528
- np .testing .assert_allclose (b , np .sum (da_lpl1 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
529
-
530
- # True Group lasso regularization
531
- reg = 1e-1
532
- eta = 2e0
533
- da_l1l2 = ot .da .OTDA_l1l2 ()
534
- da_l1l2 .fit (xs , ys , xt , reg = reg , eta = eta , numItermax = 20 , verbose = True )
535
- da_l1l2 .interp ()
536
- da_l1l2 .predict (xs )
537
-
538
- np .testing .assert_allclose (a , np .sum (da_l1l2 .G , 1 ), rtol = 1e-3 , atol = 1e-3 )
539
- np .testing .assert_allclose (b , np .sum (da_l1l2 .G , 0 ), rtol = 1e-3 , atol = 1e-3 )
540
-
541
- # linear mapping
542
- da_emd = ot .da .OTDA_mapping_linear () # init class
543
- da_emd .fit (xs , xt , numItermax = 10 ) # fit distributions
544
- da_emd .predict (xs ) # interpolation of source samples
545
-
546
- # nonlinear mapping
547
- da_emd = ot .da .OTDA_mapping_kernel () # init class
548
- da_emd .fit (xs , xt , numItermax = 10 ) # fit distributions
549
- da_emd .predict (xs ) # interpolation of source samples
0 commit comments