Skip to content

Commit 5c0ed10

Browse files
author
Hicham Janati
committed
add unbalanced tests with stabilization
1 parent 10accb1 commit 5c0ed10

File tree

1 file changed

+77
-39
lines changed

1 file changed

+77
-39
lines changed

test/test_unbalanced.py

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import ot
99
import pytest
1010

11+
from scipy.misc import logsumexp
1112

12-
@pytest.mark.parametrize("method", ["sinkhorn"])
13+
14+
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
1315
def test_unbalanced_convergence(method):
1416
# test generalized sinkhorn for unbalanced OT
1517
n = 100
@@ -23,29 +25,34 @@ def test_unbalanced_convergence(method):
2325

2426
M = ot.dist(x, x)
2527
epsilon = 1.
26-
alpha = 1.
27-
K = np.exp(- M / epsilon)
28+
mu = 1.
2829

29-
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
30+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu,
3031
stopThr=1e-10, method=method,
3132
log=True)
32-
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
33+
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
3334
method=method)
3435
# check fixed point equations
35-
fi = alpha / (alpha + epsilon)
36-
v_final = (b / K.T.dot(log["u"])) ** fi
37-
u_final = (a / K.dot(log["v"])) ** fi
36+
# in log-domain
37+
fi = mu / (mu + epsilon)
38+
logb = np.log(b + 1e-16)
39+
loga = np.log(a + 1e-16)
40+
logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
41+
logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
42+
43+
v_final = fi * (logb - logKtu)
44+
u_final = fi * (loga - logKv)
3845

3946
np.testing.assert_allclose(
40-
u_final, log["u"], atol=1e-05)
47+
u_final, log["logu"], atol=1e-05)
4148
np.testing.assert_allclose(
42-
v_final, log["v"], atol=1e-05)
49+
v_final, log["logv"], atol=1e-05)
4350

4451
# check if sinkhorn_unbalanced2 returns the correct loss
4552
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
4653

4754

48-
@pytest.mark.parametrize("method", ["sinkhorn"])
55+
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
4956
def test_unbalanced_multiple_inputs(method):
5057
# test generalized sinkhorn for unbalanced OT
5158
n = 100
@@ -59,27 +66,55 @@ def test_unbalanced_multiple_inputs(method):
5966

6067
M = ot.dist(x, x)
6168
epsilon = 1.
62-
alpha = 1.
63-
K = np.exp(- M / epsilon)
69+
mu = 1.
6470

65-
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
66-
alpha=alpha,
71+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu,
6772
stopThr=1e-10, method=method,
6873
log=True)
6974
# check fixed point equations
70-
fi = alpha / (alpha + epsilon)
71-
v_final = (b / K.T.dot(log["u"])) ** fi
72-
73-
u_final = (a[:, None] / K.dot(log["v"])) ** fi
75+
# in log-domain
76+
fi = mu / (mu + epsilon)
77+
logb = np.log(b + 1e-16)
78+
loga = np.log(a + 1e-16)[:, None]
79+
logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
80+
axis=0)
81+
logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
82+
v_final = fi * (logb - logKtu)
83+
u_final = fi * (loga - logKv)
7484

7585
np.testing.assert_allclose(
76-
u_final, log["u"], atol=1e-05)
86+
u_final, log["logu"], atol=1e-05)
7787
np.testing.assert_allclose(
78-
v_final, log["v"], atol=1e-05)
88+
v_final, log["logv"], atol=1e-05)
7989

8090
assert len(loss) == b.shape[1]
8191

8292

93+
def test_stabilized_vs_sinkhorn():
94+
# test if stable version matches sinkhorn
95+
n = 100
96+
97+
# Gaussian distributions
98+
a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
99+
b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
100+
b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
101+
102+
# creating matrix A containing all distributions
103+
b = np.vstack((b1, b2)).T
104+
105+
M = ot.utils.dist0(n)
106+
M /= np.median(M)
107+
epsilon = 0.1
108+
mu = 1.
109+
G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon,
110+
mu=mu,
111+
log=True)
112+
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
113+
method="sinkhorn", log=True)
114+
115+
np.testing.assert_allclose(G, G2)
116+
117+
83118
def test_unbalanced_barycenter():
84119
# test generalized sinkhorn for unbalanced OT barycenter
85120
n = 100
@@ -92,27 +127,30 @@ def test_unbalanced_barycenter():
92127
A = A * np.array([1, 2])[None, :]
93128
M = ot.dist(x, x)
94129
epsilon = 1.
95-
alpha = 1.
96-
K = np.exp(- M / epsilon)
130+
mu = 1.
97131

98-
q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
132+
q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu,
99133
stopThr=1e-10,
100134
log=True)
101135
# check fixed point equations
102-
fi = alpha / (alpha + epsilon)
103-
v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
104-
u_final = (A / K.dot(log["v"])) ** fi
136+
fi = mu / (mu + epsilon)
137+
logA = np.log(A + 1e-16)
138+
logq = np.log(q + 1e-16)[:, None]
139+
logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
140+
axis=0)
141+
logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
142+
v_final = fi * (logq - logKtu)
143+
u_final = fi * (logA - logKv)
105144

106145
np.testing.assert_allclose(
107-
u_final, log["u"], atol=1e-05)
146+
u_final, log["logu"], atol=1e-05)
108147
np.testing.assert_allclose(
109-
v_final, log["v"], atol=1e-05)
148+
v_final, log["logv"], atol=1e-05)
110149

111150

112151
def test_implemented_methods():
113-
IMPLEMENTED_METHODS = ['sinkhorn']
114-
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
115-
'sinkhorn_epsilon_scaling']
152+
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
153+
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
116154
NOT_VALID_TOKENS = ['foo']
117155
# test generalized sinkhorn for unbalanced OT barycenter
118156
n = 3
@@ -126,21 +164,21 @@ def test_implemented_methods():
126164

127165
M = ot.dist(x, x)
128166
epsilon = 1.
129-
alpha = 1.
167+
mu = 1.
130168
for method in IMPLEMENTED_METHODS:
131-
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
169+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
132170
method=method)
133-
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
171+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
134172
method=method)
135173
with pytest.warns(UserWarning, match='not implemented'):
136174
for method in set(TO_BE_IMPLEMENTED_METHODS):
137-
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
175+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
138176
method=method)
139-
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
177+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
140178
method=method)
141179
with pytest.raises(ValueError):
142180
for method in set(NOT_VALID_TOKENS):
143-
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
181+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
144182
method=method)
145-
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
183+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
146184
method=method)

0 commit comments

Comments
 (0)