Skip to content

Commit 5e7bfbc

Browse files
committed
working test +92 percent tets coverege
1 parent 75e7802 commit 5e7bfbc

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

test/test_bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_bary():
105105
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
106106

107107

108-
def test_wassersteinbary():
108+
def test_wasserstein_bary_2d():
109109

110110
size = 100 # size of a square image
111111
a1 = np.random.randn(size, size)

test/test_gpu.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
import ot
9-
import time
109
import pytest
1110

1211
try: # test if cudamat installed
@@ -31,7 +30,11 @@ def test_gpu_dist():
3130

3231
np.testing.assert_allclose(M, M2, rtol=1e-10)
3332

34-
M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
33+
M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False)
34+
35+
# check raise not implemented wrong metric
36+
with pytest.raises(NotImplementedError):
37+
M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False)
3538

3639

3740
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -46,6 +49,9 @@ def test_gpu_sinkhorn():
4649
wa = ot.unif(n_samples // 4)
4750
wb = ot.unif(n_samples)
4851

52+
wb2 = np.random.rand(n_samples, 20)
53+
wb2 /= wb2.sum(0, keepdims=True)
54+
4955
M = ot.dist(a.copy(), b.copy())
5056
M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
5157

@@ -56,7 +62,11 @@ def test_gpu_sinkhorn():
5662

5763
np.testing.assert_allclose(G1, G, rtol=1e-10)
5864

59-
G2 = ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False)
65+
# run all on gpu
66+
ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True)
67+
68+
# run sinkhorn for multiple targets
69+
ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True)
6070

6171

6272
@pytest.mark.skipif(nogpu, reason="No GPU available")
@@ -83,4 +93,4 @@ def test_gpu_sinkhorn_lpl1():
8393

8494
np.testing.assert_allclose(G1, G, rtol=1e-10)
8595

86-
G2 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False)
96+
ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True)

0 commit comments

Comments
 (0)