6
6
7
7
import numpy as np
8
8
import ot
9
- import time
10
9
import pytest
11
10
12
11
try : # test if cudamat installed
@@ -31,7 +30,11 @@ def test_gpu_dist():
31
30
32
31
np .testing .assert_allclose (M , M2 , rtol = 1e-10 )
33
32
34
- M2 = ot .gpu .dist (a .copy (), b .copy (), to_numpy = False )
33
+ M2 = ot .gpu .dist (a .copy (), b .copy (), metric = 'euclidean' , to_numpy = False )
34
+
35
+ # check raise not implemented wrong metric
36
+ with pytest .raises (NotImplementedError ):
37
+ M2 = ot .gpu .dist (a .copy (), b .copy (), metric = 'cityblock' , to_numpy = False )
35
38
36
39
37
40
@pytest .mark .skipif (nogpu , reason = "No GPU available" )
@@ -46,6 +49,9 @@ def test_gpu_sinkhorn():
46
49
wa = ot .unif (n_samples // 4 )
47
50
wb = ot .unif (n_samples )
48
51
52
+ wb2 = np .random .rand (n_samples , 20 )
53
+ wb2 /= wb2 .sum (0 , keepdims = True )
54
+
49
55
M = ot .dist (a .copy (), b .copy ())
50
56
M2 = ot .gpu .dist (a .copy (), b .copy (), to_numpy = False )
51
57
@@ -56,7 +62,11 @@ def test_gpu_sinkhorn():
56
62
57
63
np .testing .assert_allclose (G1 , G , rtol = 1e-10 )
58
64
59
- G2 = ot .gpu .sinkhorn (wa , wb , M2 , reg , to_numpy = False )
65
+ # run all on gpu
66
+ ot .gpu .sinkhorn (wa , wb , M2 , reg , to_numpy = False , log = True )
67
+
68
+ # run sinkhorn for multiple targets
69
+ ot .gpu .sinkhorn (wa , wb2 , M2 , reg , to_numpy = False , log = True )
60
70
61
71
62
72
@pytest .mark .skipif (nogpu , reason = "No GPU available" )
@@ -83,4 +93,4 @@ def test_gpu_sinkhorn_lpl1():
83
93
84
94
np .testing .assert_allclose (G1 , G , rtol = 1e-10 )
85
95
86
- G2 = ot .gpu .da .sinkhorn_lpl1_mm (wa , labels_a , wb , M2 , reg , to_numpy = False )
96
+ ot .gpu .da .sinkhorn_lpl1_mm (wa , labels_a , wb , M2 , reg , to_numpy = False , log = True )
0 commit comments