Skip to content

Commit a507556

Browse files
author
Hicham Janati
committed
rebase unbalanced
1 parent a725f1d commit a507556

File tree

1 file changed

+39
-77
lines changed

1 file changed

+39
-77
lines changed

test/test_unbalanced.py

Lines changed: 39 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@
88
import ot
99
import pytest
1010

11-
from scipy.misc import logsumexp
1211

13-
14-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
12+
@pytest.mark.parametrize("method", ["sinkhorn"])
1513
def test_unbalanced_convergence(method):
1614
# test generalized sinkhorn for unbalanced OT
1715
n = 100
@@ -25,34 +23,29 @@ def test_unbalanced_convergence(method):
2523

2624
M = ot.dist(x, x)
2725
epsilon = 1.
28-
mu = 1.
26+
alpha = 1.
27+
K = np.exp(- M / epsilon)
2928

30-
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu,
29+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
3130
stopThr=1e-10, method=method,
3231
log=True)
33-
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
32+
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
3433
method=method)
3534
# check fixed point equations
36-
# in log-domain
37-
fi = mu / (mu + epsilon)
38-
logb = np.log(b + 1e-16)
39-
loga = np.log(a + 1e-16)
40-
logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
41-
logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
42-
43-
v_final = fi * (logb - logKtu)
44-
u_final = fi * (loga - logKv)
35+
fi = alpha / (alpha + epsilon)
36+
v_final = (b / K.T.dot(log["u"])) ** fi
37+
u_final = (a / K.dot(log["v"])) ** fi
4538

4639
np.testing.assert_allclose(
47-
u_final, log["logu"], atol=1e-05)
40+
u_final, log["u"], atol=1e-05)
4841
np.testing.assert_allclose(
49-
v_final, log["logv"], atol=1e-05)
42+
v_final, log["v"], atol=1e-05)
5043

5144
# check if sinkhorn_unbalanced2 returns the correct loss
5245
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
5346

5447

55-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
48+
@pytest.mark.parametrize("method", ["sinkhorn"])
5649
def test_unbalanced_multiple_inputs(method):
5750
# test generalized sinkhorn for unbalanced OT
5851
n = 100
@@ -66,55 +59,27 @@ def test_unbalanced_multiple_inputs(method):
6659

6760
M = ot.dist(x, x)
6861
epsilon = 1.
69-
mu = 1.
62+
alpha = 1.
63+
K = np.exp(- M / epsilon)
7064

71-
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu,
65+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
66+
alpha=alpha,
7267
stopThr=1e-10, method=method,
7368
log=True)
7469
# check fixed point equations
75-
# in log-domain
76-
fi = mu / (mu + epsilon)
77-
logb = np.log(b + 1e-16)
78-
loga = np.log(a + 1e-16)[:, None]
79-
logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
80-
axis=0)
81-
logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
82-
v_final = fi * (logb - logKtu)
83-
u_final = fi * (loga - logKv)
70+
fi = alpha / (alpha + epsilon)
71+
v_final = (b / K.T.dot(log["u"])) ** fi
72+
73+
u_final = (a[:, None] / K.dot(log["v"])) ** fi
8474

8575
np.testing.assert_allclose(
86-
u_final, log["logu"], atol=1e-05)
76+
u_final, log["u"], atol=1e-05)
8777
np.testing.assert_allclose(
88-
v_final, log["logv"], atol=1e-05)
78+
v_final, log["v"], atol=1e-05)
8979

9080
assert len(loss) == b.shape[1]
9181

9282

93-
def test_stabilized_vs_sinkhorn():
94-
# test if stable version matches sinkhorn
95-
n = 100
96-
97-
# Gaussian distributions
98-
a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
99-
b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
100-
b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
101-
102-
# creating matrix A containing all distributions
103-
b = np.vstack((b1, b2)).T
104-
105-
M = ot.utils.dist0(n)
106-
M /= np.median(M)
107-
epsilon = 0.1
108-
mu = 1.
109-
G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon,
110-
mu=mu,
111-
log=True)
112-
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
113-
method="sinkhorn", log=True)
114-
115-
np.testing.assert_allclose(G, G2)
116-
117-
11883
def test_unbalanced_barycenter():
11984
# test generalized sinkhorn for unbalanced OT barycenter
12085
n = 100
@@ -127,30 +92,27 @@ def test_unbalanced_barycenter():
12792
A = A * np.array([1, 2])[None, :]
12893
M = ot.dist(x, x)
12994
epsilon = 1.
130-
mu = 1.
95+
alpha = 1.
96+
K = np.exp(- M / epsilon)
13197

132-
q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu,
98+
q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
13399
stopThr=1e-10,
134100
log=True)
135101
# check fixed point equations
136-
fi = mu / (mu + epsilon)
137-
logA = np.log(A + 1e-16)
138-
logq = np.log(q + 1e-16)[:, None]
139-
logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
140-
axis=0)
141-
logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
142-
v_final = fi * (logq - logKtu)
143-
u_final = fi * (logA - logKv)
102+
fi = alpha / (alpha + epsilon)
103+
v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
104+
u_final = (A / K.dot(log["v"])) ** fi
144105

145106
np.testing.assert_allclose(
146-
u_final, log["logu"], atol=1e-05)
107+
u_final, log["u"], atol=1e-05)
147108
np.testing.assert_allclose(
148-
v_final, log["logv"], atol=1e-05)
109+
v_final, log["v"], atol=1e-05)
149110

150111

151112
def test_implemented_methods():
152-
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
153-
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
113+
IMPLEMENTED_METHODS = ['sinkhorn']
114+
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
115+
'sinkhorn_epsilon_scaling']
154116
NOT_VALID_TOKENS = ['foo']
155117
# test generalized sinkhorn for unbalanced OT barycenter
156118
n = 3
@@ -164,21 +126,21 @@ def test_implemented_methods():
164126

165127
M = ot.dist(x, x)
166128
epsilon = 1.
167-
mu = 1.
129+
alpha = 1.
168130
for method in IMPLEMENTED_METHODS:
169-
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
131+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
170132
method=method)
171-
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
133+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
172134
method=method)
173135
with pytest.warns(UserWarning, match='not implemented'):
174136
for method in set(TO_BE_IMPLEMENTED_METHODS):
175-
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
137+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
176138
method=method)
177-
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
139+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
178140
method=method)
179141
with pytest.raises(ValueError):
180142
for method in set(NOT_VALID_TOKENS):
181-
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
143+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
182144
method=method)
183-
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
145+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
184146
method=method)

0 commit comments

Comments
 (0)