@@ -18,8 +18,9 @@ def test_doctest():
18
18
19
19
20
20
def test_emd_emd2 ():
21
- # test emd
21
+ # test emd and emd2 for simple identity
22
22
n = 100
23
+ np .random .seed (0 )
23
24
24
25
x = np .random .randn (n , 2 )
25
26
u = ot .utils .unif (n )
@@ -35,14 +36,13 @@ def test_emd_emd2():
35
36
36
37
# check loss=0
37
38
assert np .allclose (w , 0 )
38
-
39
-
40
- #@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
39
+
41
40
def test_emd2_multi ():
42
41
43
42
from ot .datasets import get_1D_gauss as gauss
44
43
45
44
n = 1000 # nb bins
45
+ np .random .seed (0 )
46
46
47
47
# bin positions
48
48
x = np .arange (n , dtype = np .float64 )
@@ -72,4 +72,40 @@ def test_emd2_multi():
72
72
emdn = ot .emd2 (a , b , M )
73
73
ot .toc ('multi proc : {} s' )
74
74
75
- assert np .allclose (emd1 , emdn )
75
+ assert np .allclose (emd1 , emdn )
76
+
77
+
78
+ def test_sinkhorn ():
79
+ # test sinkhorn
80
+ n = 100
81
+ np .random .seed (0 )
82
+
83
+ x = np .random .randn (n , 2 )
84
+ u = ot .utils .unif (n )
85
+
86
+ M = ot .dist (x , x )
87
+
88
+ G = ot .sinkhorn (u , u , M ,1 ,stopThr = 1e-10 )
89
+
90
+ # check constratints
91
+ assert np .allclose (u , G .sum (1 ), atol = 1e-05 ) # cf convergence sinkhorn
92
+ assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
93
+
94
+ def test_sinkhorn_variants ():
95
+ # test sinkhorn
96
+ n = 100
97
+ np .random .seed (0 )
98
+
99
+ x = np .random .randn (n , 2 )
100
+ u = ot .utils .unif (n )
101
+
102
+ M = ot .dist (x , x )
103
+
104
+ G0 = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn' ,stopThr = 1e-10 )
105
+ Gs = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn_stabilized' ,stopThr = 1e-10 )
106
+ Ges = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn_epsilon_scaling' ,stopThr = 1e-10 )
107
+
108
+ # check constratints
109
+ assert np .allclose (G0 , Gs , atol = 1e-05 )
110
+ assert np .allclose (G0 , Ges , atol = 1e-05 ) #
111
+
0 commit comments