8
8
import ot
9
9
import pytest
10
10
11
- from scipy .misc import logsumexp
12
11
13
-
14
- @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
12
+ @pytest .mark .parametrize ("method" , ["sinkhorn" ])
15
13
def test_unbalanced_convergence (method ):
16
14
# test generalized sinkhorn for unbalanced OT
17
15
n = 100
@@ -25,34 +23,29 @@ def test_unbalanced_convergence(method):
25
23
26
24
M = ot .dist (x , x )
27
25
epsilon = 1.
28
- mu = 1.
26
+ alpha = 1.
27
+ K = np .exp (- M / epsilon )
29
28
30
- G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
29
+ G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , alpha = alpha ,
31
30
stopThr = 1e-10 , method = method ,
32
31
log = True )
33
- loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
32
+ loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
34
33
method = method )
35
34
# check fixed point equations
36
- # in log-domain
37
- fi = mu / (mu + epsilon )
38
- logb = np .log (b + 1e-16 )
39
- loga = np .log (a + 1e-16 )
40
- logKtu = logsumexp (log ["logu" ][None , :] - M .T / epsilon , axis = 1 )
41
- logKv = logsumexp (log ["logv" ][None , :] - M / epsilon , axis = 1 )
42
-
43
- v_final = fi * (logb - logKtu )
44
- u_final = fi * (loga - logKv )
35
+ fi = alpha / (alpha + epsilon )
36
+ v_final = (b / K .T .dot (log ["u" ])) ** fi
37
+ u_final = (a / K .dot (log ["v" ])) ** fi
45
38
46
39
np .testing .assert_allclose (
47
- u_final , log ["logu " ], atol = 1e-05 )
40
+ u_final , log ["u " ], atol = 1e-05 )
48
41
np .testing .assert_allclose (
49
- v_final , log ["logv " ], atol = 1e-05 )
42
+ v_final , log ["v " ], atol = 1e-05 )
50
43
51
44
# check if sinkhorn_unbalanced2 returns the correct loss
52
45
np .testing .assert_allclose ((G * M ).sum (), loss , atol = 1e-5 )
53
46
54
47
55
- @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
48
+ @pytest .mark .parametrize ("method" , ["sinkhorn" ])
56
49
def test_unbalanced_multiple_inputs (method ):
57
50
# test generalized sinkhorn for unbalanced OT
58
51
n = 100
@@ -66,55 +59,27 @@ def test_unbalanced_multiple_inputs(method):
66
59
67
60
M = ot .dist (x , x )
68
61
epsilon = 1.
69
- mu = 1.
62
+ alpha = 1.
63
+ K = np .exp (- M / epsilon )
70
64
71
- loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
65
+ loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon ,
66
+ alpha = alpha ,
72
67
stopThr = 1e-10 , method = method ,
73
68
log = True )
74
69
# check fixed point equations
75
- # in log-domain
76
- fi = mu / (mu + epsilon )
77
- logb = np .log (b + 1e-16 )
78
- loga = np .log (a + 1e-16 )[:, None ]
79
- logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
80
- axis = 0 )
81
- logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
82
- v_final = fi * (logb - logKtu )
83
- u_final = fi * (loga - logKv )
70
+ fi = alpha / (alpha + epsilon )
71
+ v_final = (b / K .T .dot (log ["u" ])) ** fi
72
+
73
+ u_final = (a [:, None ] / K .dot (log ["v" ])) ** fi
84
74
85
75
np .testing .assert_allclose (
86
- u_final , log ["logu " ], atol = 1e-05 )
76
+ u_final , log ["u " ], atol = 1e-05 )
87
77
np .testing .assert_allclose (
88
- v_final , log ["logv " ], atol = 1e-05 )
78
+ v_final , log ["v " ], atol = 1e-05 )
89
79
90
80
assert len (loss ) == b .shape [1 ]
91
81
92
82
93
- def test_stabilized_vs_sinkhorn ():
94
- # test if stable version matches sinkhorn
95
- n = 100
96
-
97
- # Gaussian distributions
98
- a = ot .datasets .make_1D_gauss (n , m = 20 , s = 5 ) # m= mean, s= std
99
- b1 = ot .datasets .make_1D_gauss (n , m = 60 , s = 8 )
100
- b2 = ot .datasets .make_1D_gauss (n , m = 30 , s = 4 )
101
-
102
- # creating matrix A containing all distributions
103
- b = np .vstack ((b1 , b2 )).T
104
-
105
- M = ot .utils .dist0 (n )
106
- M /= np .median (M )
107
- epsilon = 0.1
108
- mu = 1.
109
- G , log = ot .unbalanced .sinkhorn_stabilized_unbalanced (a , b , M , reg = epsilon ,
110
- mu = mu ,
111
- log = True )
112
- G2 , log2 = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
113
- method = "sinkhorn" , log = True )
114
-
115
- np .testing .assert_allclose (G , G2 )
116
-
117
-
118
83
def test_unbalanced_barycenter ():
119
84
# test generalized sinkhorn for unbalanced OT barycenter
120
85
n = 100
@@ -127,30 +92,27 @@ def test_unbalanced_barycenter():
127
92
A = A * np .array ([1 , 2 ])[None , :]
128
93
M = ot .dist (x , x )
129
94
epsilon = 1.
130
- mu = 1.
95
+ alpha = 1.
96
+ K = np .exp (- M / epsilon )
131
97
132
- q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , mu = mu ,
98
+ q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , alpha = alpha ,
133
99
stopThr = 1e-10 ,
134
100
log = True )
135
101
# check fixed point equations
136
- fi = mu / (mu + epsilon )
137
- logA = np .log (A + 1e-16 )
138
- logq = np .log (q + 1e-16 )[:, None ]
139
- logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
140
- axis = 0 )
141
- logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
142
- v_final = fi * (logq - logKtu )
143
- u_final = fi * (logA - logKv )
102
+ fi = alpha / (alpha + epsilon )
103
+ v_final = (q [:, None ] / K .T .dot (log ["u" ])) ** fi
104
+ u_final = (A / K .dot (log ["v" ])) ** fi
144
105
145
106
np .testing .assert_allclose (
146
- u_final , log ["logu " ], atol = 1e-05 )
107
+ u_final , log ["u " ], atol = 1e-05 )
147
108
np .testing .assert_allclose (
148
- v_final , log ["logv " ], atol = 1e-05 )
109
+ v_final , log ["v " ], atol = 1e-05 )
149
110
150
111
151
112
def test_implemented_methods ():
152
- IMPLEMENTED_METHODS = ['sinkhorn' , 'sinkhorn_stabilized' ]
153
- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling' ]
113
+ IMPLEMENTED_METHODS = ['sinkhorn' ]
114
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized' ,
115
+ 'sinkhorn_epsilon_scaling' ]
154
116
NOT_VALID_TOKENS = ['foo' ]
155
117
# test generalized sinkhorn for unbalanced OT barycenter
156
118
n = 3
@@ -164,21 +126,21 @@ def test_implemented_methods():
164
126
165
127
M = ot .dist (x , x )
166
128
epsilon = 1.
167
- mu = 1.
129
+ alpha = 1.
168
130
for method in IMPLEMENTED_METHODS :
169
- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
131
+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
170
132
method = method )
171
- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
133
+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
172
134
method = method )
173
135
with pytest .warns (UserWarning , match = 'not implemented' ):
174
136
for method in set (TO_BE_IMPLEMENTED_METHODS ):
175
- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
137
+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
176
138
method = method )
177
- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
139
+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
178
140
method = method )
179
141
with pytest .raises (ValueError ):
180
142
for method in set (NOT_VALID_TOKENS ):
181
- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
143
+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
182
144
method = method )
183
- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
145
+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
184
146
method = method )
0 commit comments