@@ -36,7 +36,8 @@ def test_emd_emd2():
36
36
37
37
# check loss=0
38
38
assert np .allclose (w , 0 )
39
-
39
+
40
+
40
41
def test_emd2_multi ():
41
42
42
43
from ot .datasets import get_1D_gauss as gauss
@@ -72,11 +73,11 @@ def test_emd2_multi():
72
73
emdn = ot .emd2 (a , b , M )
73
74
ot .toc ('multi proc : {} s' )
74
75
75
- assert np .allclose (emd1 , emdn )
76
-
77
-
76
+ assert np .allclose (emd1 , emdn )
77
+
78
+
78
79
def test_sinkhorn ():
79
- # test sinkhorn
80
+ # test sinkhorn
80
81
n = 100
81
82
np .random .seed (0 )
82
83
@@ -85,14 +86,15 @@ def test_sinkhorn():
85
86
86
87
M = ot .dist (x , x )
87
88
88
- G = ot .sinkhorn (u , u , M ,1 , stopThr = 1e-10 )
89
+ G = ot .sinkhorn (u , u , M , 1 , stopThr = 1e-10 )
89
90
90
91
# check constratints
91
- assert np .allclose (u , G .sum (1 ), atol = 1e-05 ) # cf convergence sinkhorn
92
- assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
93
-
92
+ assert np .allclose (u , G .sum (1 ), atol = 1e-05 ) # cf convergence sinkhorn
93
+ assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
94
+
95
+
94
96
def test_sinkhorn_variants ():
95
- # test sinkhorn
97
+ # test sinkhorn
96
98
n = 100
97
99
np .random .seed (0 )
98
100
@@ -101,11 +103,11 @@ def test_sinkhorn_variants():
101
103
102
104
M = ot .dist (x , x )
103
105
104
- G0 = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn' ,stopThr = 1e-10 )
105
- Gs = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn_stabilized' ,stopThr = 1e-10 )
106
- Ges = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn_epsilon_scaling' ,stopThr = 1e-10 )
106
+ G0 = ot .sinkhorn (u , u , M , 1 , method = 'sinkhorn' , stopThr = 1e-10 )
107
+ Gs = ot .sinkhorn (u , u , M , 1 , method = 'sinkhorn_stabilized' , stopThr = 1e-10 )
108
+ Ges = ot .sinkhorn (
109
+ u , u , M , 1 , method = 'sinkhorn_epsilon_scaling' , stopThr = 1e-10 )
107
110
108
111
# check constratints
109
- assert np .allclose (G0 , Gs , atol = 1e-05 )
110
- assert np .allclose (G0 , Ges , atol = 1e-05 ) #
111
-
112
+ assert np .allclose (G0 , Gs , atol = 1e-05 )
113
+ assert np .allclose (G0 , Ges , atol = 1e-05 )
0 commit comments