8
8
import ot
9
9
import pytest
10
10
11
+ from scipy .misc import logsumexp
11
12
12
- @pytest .mark .parametrize ("method" , ["sinkhorn" ])
13
+
14
+ @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
13
15
def test_unbalanced_convergence (method ):
14
16
# test generalized sinkhorn for unbalanced OT
15
17
n = 100
@@ -23,29 +25,34 @@ def test_unbalanced_convergence(method):
23
25
24
26
M = ot .dist (x , x )
25
27
epsilon = 1.
26
- alpha = 1.
27
- K = np .exp (- M / epsilon )
28
+ mu = 1.
28
29
29
- G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , alpha = alpha ,
30
+ G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
30
31
stopThr = 1e-10 , method = method ,
31
32
log = True )
32
- loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
33
+ loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
33
34
method = method )
34
35
# check fixed point equations
35
- fi = alpha / (alpha + epsilon )
36
- v_final = (b / K .T .dot (log ["u" ])) ** fi
37
- u_final = (a / K .dot (log ["v" ])) ** fi
36
+ # in log-domain
37
+ fi = mu / (mu + epsilon )
38
+ logb = np .log (b + 1e-16 )
39
+ loga = np .log (a + 1e-16 )
40
+ logKtu = logsumexp (log ["logu" ][None , :] - M .T / epsilon , axis = 1 )
41
+ logKv = logsumexp (log ["logv" ][None , :] - M / epsilon , axis = 1 )
42
+
43
+ v_final = fi * (logb - logKtu )
44
+ u_final = fi * (loga - logKv )
38
45
39
46
np .testing .assert_allclose (
40
- u_final , log ["u " ], atol = 1e-05 )
47
+ u_final , log ["logu " ], atol = 1e-05 )
41
48
np .testing .assert_allclose (
42
- v_final , log ["v " ], atol = 1e-05 )
49
+ v_final , log ["logv " ], atol = 1e-05 )
43
50
44
51
# check if sinkhorn_unbalanced2 returns the correct loss
45
52
np .testing .assert_allclose ((G * M ).sum (), loss , atol = 1e-5 )
46
53
47
54
48
- @pytest .mark .parametrize ("method" , ["sinkhorn" ])
55
+ @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
49
56
def test_unbalanced_multiple_inputs (method ):
50
57
# test generalized sinkhorn for unbalanced OT
51
58
n = 100
@@ -59,27 +66,55 @@ def test_unbalanced_multiple_inputs(method):
59
66
60
67
M = ot .dist (x , x )
61
68
epsilon = 1.
62
- alpha = 1.
63
- K = np .exp (- M / epsilon )
69
+ mu = 1.
64
70
65
- loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon ,
66
- alpha = alpha ,
71
+ loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
67
72
stopThr = 1e-10 , method = method ,
68
73
log = True )
69
74
# check fixed point equations
70
- fi = alpha / (alpha + epsilon )
71
- v_final = (b / K .T .dot (log ["u" ])) ** fi
72
-
73
- u_final = (a [:, None ] / K .dot (log ["v" ])) ** fi
75
+ # in log-domain
76
+ fi = mu / (mu + epsilon )
77
+ logb = np .log (b + 1e-16 )
78
+ loga = np .log (a + 1e-16 )[:, None ]
79
+ logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
80
+ axis = 0 )
81
+ logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
82
+ v_final = fi * (logb - logKtu )
83
+ u_final = fi * (loga - logKv )
74
84
75
85
np .testing .assert_allclose (
76
- u_final , log ["u " ], atol = 1e-05 )
86
+ u_final , log ["logu " ], atol = 1e-05 )
77
87
np .testing .assert_allclose (
78
- v_final , log ["v " ], atol = 1e-05 )
88
+ v_final , log ["logv " ], atol = 1e-05 )
79
89
80
90
assert len (loss ) == b .shape [1 ]
81
91
82
92
93
+ def test_stabilized_vs_sinkhorn ():
94
+ # test if stable version matches sinkhorn
95
+ n = 100
96
+
97
+ # Gaussian distributions
98
+ a = ot .datasets .make_1D_gauss (n , m = 20 , s = 5 ) # m= mean, s= std
99
+ b1 = ot .datasets .make_1D_gauss (n , m = 60 , s = 8 )
100
+ b2 = ot .datasets .make_1D_gauss (n , m = 30 , s = 4 )
101
+
102
+ # creating matrix A containing all distributions
103
+ b = np .vstack ((b1 , b2 )).T
104
+
105
+ M = ot .utils .dist0 (n )
106
+ M /= np .median (M )
107
+ epsilon = 0.1
108
+ mu = 1.
109
+ G , log = ot .unbalanced .sinkhorn_stabilized_unbalanced (a , b , M , reg = epsilon ,
110
+ mu = mu ,
111
+ log = True )
112
+ G2 , log2 = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
113
+ method = "sinkhorn" , log = True )
114
+
115
+ np .testing .assert_allclose (G , G2 )
116
+
117
+
83
118
def test_unbalanced_barycenter ():
84
119
# test generalized sinkhorn for unbalanced OT barycenter
85
120
n = 100
@@ -92,27 +127,30 @@ def test_unbalanced_barycenter():
92
127
A = A * np .array ([1 , 2 ])[None , :]
93
128
M = ot .dist (x , x )
94
129
epsilon = 1.
95
- alpha = 1.
96
- K = np .exp (- M / epsilon )
130
+ mu = 1.
97
131
98
- q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , alpha = alpha ,
132
+ q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , mu = mu ,
99
133
stopThr = 1e-10 ,
100
134
log = True )
101
135
# check fixed point equations
102
- fi = alpha / (alpha + epsilon )
103
- v_final = (q [:, None ] / K .T .dot (log ["u" ])) ** fi
104
- u_final = (A / K .dot (log ["v" ])) ** fi
136
+ fi = mu / (mu + epsilon )
137
+ logA = np .log (A + 1e-16 )
138
+ logq = np .log (q + 1e-16 )[:, None ]
139
+ logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
140
+ axis = 0 )
141
+ logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
142
+ v_final = fi * (logq - logKtu )
143
+ u_final = fi * (logA - logKv )
105
144
106
145
np .testing .assert_allclose (
107
- u_final , log ["u " ], atol = 1e-05 )
146
+ u_final , log ["logu " ], atol = 1e-05 )
108
147
np .testing .assert_allclose (
109
- v_final , log ["v " ], atol = 1e-05 )
148
+ v_final , log ["logv " ], atol = 1e-05 )
110
149
111
150
112
151
def test_implemented_methods ():
113
- IMPLEMENTED_METHODS = ['sinkhorn' ]
114
- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized' ,
115
- 'sinkhorn_epsilon_scaling' ]
152
+ IMPLEMENTED_METHODS = ['sinkhorn' , 'sinkhorn_stabilized' ]
153
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling' ]
116
154
NOT_VALID_TOKENS = ['foo' ]
117
155
# test generalized sinkhorn for unbalanced OT barycenter
118
156
n = 3
@@ -126,21 +164,21 @@ def test_implemented_methods():
126
164
127
165
M = ot .dist (x , x )
128
166
epsilon = 1.
129
- alpha = 1.
167
+ mu = 1.
130
168
for method in IMPLEMENTED_METHODS :
131
- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
169
+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
132
170
method = method )
133
- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
171
+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
134
172
method = method )
135
173
with pytest .warns (UserWarning , match = 'not implemented' ):
136
174
for method in set (TO_BE_IMPLEMENTED_METHODS ):
137
- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
175
+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
138
176
method = method )
139
- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
177
+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
140
178
method = method )
141
179
with pytest .raises (ValueError ):
142
180
for method in set (NOT_VALID_TOKENS ):
143
- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
181
+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
144
182
method = method )
145
- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
183
+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
146
184
method = method )
0 commit comments