Skip to content

Commit ff104a6

Browse files
committed
add test for emd and emd2
1 parent 01f8c44 commit ff104a6

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

test/test_ot.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ot
44
import numpy as np
55

6-
#import pytest
6+
# import pytest
77

88

99
def test_doctest():
@@ -17,8 +17,28 @@ def test_doctest():
1717
doctest.testmod(ot.bregman, verbose=True)
1818

1919

20+
def test_emd_emd2():
21+
# test emd
22+
n = 100
23+
24+
x = np.random.randn(n, 2)
25+
u = ot.utils.unif(n)
26+
27+
M = ot.dist(x, x)
28+
29+
G = ot.emd(u, u, M)
30+
31+
# check G is identity
32+
assert np.allclose(G, np.eye(n) / n)
33+
34+
w = ot.emd2(u, u, M)
35+
36+
# check loss=0
37+
assert np.allclose(w, 0)
38+
39+
2040
#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
21-
def test_emd_multi():
41+
def test_emd2_multi():
2242

2343
from ot.datasets import get_1D_gauss as gauss
2444

0 commit comments

Comments
 (0)