Skip to content

Commit 7549282

Browse files
committed
add test sinkhorn
1 parent ff104a6 commit 7549282

File tree

3 files changed

+45
-6
lines changed

3 files changed

+45
-6
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ pep8 :
3939

4040
test : FORCE pep8
4141
python -m py.test -v test/
42+
43+
pytest : FORCE
44+
python -m py.test -v test/
4245

4346
uploadpypi :
4447
#python setup.py register

ot/gpu/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
1111
log=False, returnAsGPU=False):
12-
"""
12+
r"""
1313
Solve the entropic regularization optimal transport problem on GPU
1414
1515
The function solves the following optimization problem:

test/test_ot.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def test_doctest():
1818

1919

2020
def test_emd_emd2():
21-
# test emd
21+
# test emd and emd2 for simple identity
2222
n = 100
23+
np.random.seed(0)
2324

2425
x = np.random.randn(n, 2)
2526
u = ot.utils.unif(n)
@@ -35,14 +36,13 @@ def test_emd_emd2():
3536

3637
# check loss=0
3738
assert np.allclose(w, 0)
38-
39-
40-
#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
39+
4140
def test_emd2_multi():
4241

4342
from ot.datasets import get_1D_gauss as gauss
4443

4544
n = 1000 # nb bins
45+
np.random.seed(0)
4646

4747
# bin positions
4848
x = np.arange(n, dtype=np.float64)
@@ -72,4 +72,40 @@ def test_emd2_multi():
7272
emdn = ot.emd2(a, b, M)
7373
ot.toc('multi proc : {} s')
7474

75-
assert np.allclose(emd1, emdn)
75+
assert np.allclose(emd1, emdn)
76+
77+
78+
def test_sinkhorn():
79+
# test sinkhorn
80+
n = 100
81+
np.random.seed(0)
82+
83+
x = np.random.randn(n, 2)
84+
u = ot.utils.unif(n)
85+
86+
M = ot.dist(x, x)
87+
88+
G = ot.sinkhorn(u, u, M,1,stopThr=1e-10)
89+
90+
# check constratints
91+
assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
92+
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
93+
94+
def test_sinkhorn_variants():
95+
# test sinkhorn
96+
n = 100
97+
np.random.seed(0)
98+
99+
x = np.random.randn(n, 2)
100+
u = ot.utils.unif(n)
101+
102+
M = ot.dist(x, x)
103+
104+
G0 = ot.sinkhorn(u, u, M,1, method='sinkhorn',stopThr=1e-10)
105+
Gs = ot.sinkhorn(u, u, M,1, method='sinkhorn_stabilized',stopThr=1e-10)
106+
Ges = ot.sinkhorn(u, u, M,1, method='sinkhorn_epsilon_scaling',stopThr=1e-10)
107+
108+
# check constratints
109+
assert np.allclose(G0, Gs, atol=1e-05)
110+
assert np.allclose(G0, Ges, atol=1e-05) #
111+

0 commit comments

Comments
 (0)