Skip to content

Commit a8a0995

Browse files
committed
pep8 tests
1 parent 7549282 commit a8a0995

File tree

2 files changed

+61
-16
lines changed

2 files changed

+61
-16
lines changed

test/test_bregman.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
3+
import ot
4+
import numpy as np
5+
6+
# import pytest
7+
8+
9+
def test_sinkhorn():
10+
# test sinkhorn
11+
n = 100
12+
np.random.seed(0)
13+
14+
x = np.random.randn(n, 2)
15+
u = ot.utils.unif(n)
16+
17+
M = ot.dist(x, x)
18+
19+
G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
20+
21+
# check constratints
22+
assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
23+
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
24+
25+
26+
def test_sinkhorn_variants():
27+
# test sinkhorn
28+
n = 100
29+
np.random.seed(0)
30+
31+
x = np.random.randn(n, 2)
32+
u = ot.utils.unif(n)
33+
34+
M = ot.dist(x, x)
35+
36+
G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
37+
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
38+
Ges = ot.sinkhorn(
39+
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
40+
41+
# check constratints
42+
assert np.allclose(G0, Gs, atol=1e-05)
43+
assert np.allclose(G0, Ges, atol=1e-05)

test/test_ot.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def test_emd_emd2():
3636

3737
# check loss=0
3838
assert np.allclose(w, 0)
39-
39+
40+
4041
def test_emd2_multi():
4142

4243
from ot.datasets import get_1D_gauss as gauss
@@ -72,11 +73,11 @@ def test_emd2_multi():
7273
emdn = ot.emd2(a, b, M)
7374
ot.toc('multi proc : {} s')
7475

75-
assert np.allclose(emd1, emdn)
76-
77-
76+
assert np.allclose(emd1, emdn)
77+
78+
7879
def test_sinkhorn():
79-
# test sinkhorn
80+
# test sinkhorn
8081
n = 100
8182
np.random.seed(0)
8283

@@ -85,14 +86,15 @@ def test_sinkhorn():
8586

8687
M = ot.dist(x, x)
8788

88-
G = ot.sinkhorn(u, u, M,1,stopThr=1e-10)
89+
G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
8990

9091
# check constratints
91-
assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
92-
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
93-
92+
assert np.allclose(u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
93+
assert np.allclose(u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
94+
95+
9496
def test_sinkhorn_variants():
95-
# test sinkhorn
97+
# test sinkhorn
9698
n = 100
9799
np.random.seed(0)
98100

@@ -101,11 +103,11 @@ def test_sinkhorn_variants():
101103

102104
M = ot.dist(x, x)
103105

104-
G0 = ot.sinkhorn(u, u, M,1, method='sinkhorn',stopThr=1e-10)
105-
Gs = ot.sinkhorn(u, u, M,1, method='sinkhorn_stabilized',stopThr=1e-10)
106-
Ges = ot.sinkhorn(u, u, M,1, method='sinkhorn_epsilon_scaling',stopThr=1e-10)
106+
G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
107+
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
108+
Ges = ot.sinkhorn(
109+
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
107110

108111
# check constratints
109-
assert np.allclose(G0, Gs, atol=1e-05)
110-
assert np.allclose(G0, Ges, atol=1e-05) #
111-
112+
assert np.allclose(G0, Gs, atol=1e-05)
113+
assert np.allclose(G0, Ges, atol=1e-05)

0 commit comments

Comments
 (0)