7
7
import warnings
8
8
9
9
import numpy as np
10
+ import pytest
10
11
from scipy .stats import wasserstein_distance
11
12
12
13
import ot
13
14
from ot .datasets import make_1D_gauss as gauss
14
- import pytest
15
15
16
16
17
17
def test_emd_dimension_mismatch ():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
75
75
np .testing .assert_allclose (wass , wass1d_emd2 )
76
76
77
77
# check loss is similar to scipy's implementation for Euclidean metric
78
- wass_sp = wasserstein_distance (u .reshape ((- 1 , )), v .reshape ((- 1 , )))
78
+ wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)))
79
79
np .testing .assert_allclose (wass_sp , wass1d_euc )
80
80
81
81
# check constraints
82
- np .testing .assert_allclose (np .ones ((n , )) / n , G .sum (1 ))
83
- np .testing .assert_allclose (np .ones ((m , )) / m , G .sum (0 ))
82
+ np .testing .assert_allclose (np .ones ((n ,)) / n , G .sum (1 ))
83
+ np .testing .assert_allclose (np .ones ((m ,)) / m , G .sum (0 ))
84
84
85
85
# check G is similar
86
86
np .testing .assert_allclose (G , G_1d )
@@ -92,6 +92,42 @@ def test_emd_1d_emd2_1d():
92
92
ot .emd_1d (u , v , [], [])
93
93
94
94
95
+ def test_emd_1d_emd2_1d_with_weights ():
96
+ # test emd1d gives similar results as emd
97
+ n = 20
98
+ m = 30
99
+ rng = np .random .RandomState (0 )
100
+ u = rng .randn (n , 1 )
101
+ v = rng .randn (m , 1 )
102
+
103
+ w_u = rng .uniform (0. , 1. , n )
104
+ w_u = w_u / w_u .sum ()
105
+
106
+ w_v = rng .uniform (0. , 1. , m )
107
+ w_v = w_v / w_v .sum ()
108
+
109
+ M = ot .dist (u , v , metric = 'sqeuclidean' )
110
+
111
+ G , log = ot .emd (w_u , w_v , M , log = True )
112
+ wass = log ["cost" ]
113
+ G_1d , log = ot .emd_1d (u , v , w_u , w_v , metric = 'sqeuclidean' , log = True )
114
+ wass1d = log ["cost" ]
115
+ wass1d_emd2 = ot .emd2_1d (u , v , w_u , w_v , metric = 'sqeuclidean' , log = False )
116
+ wass1d_euc = ot .emd2_1d (u , v , w_u , w_v , metric = 'euclidean' , log = False )
117
+
118
+ # check loss is similar
119
+ np .testing .assert_allclose (wass , wass1d )
120
+ np .testing .assert_allclose (wass , wass1d_emd2 )
121
+
122
+ # check loss is similar to scipy's implementation for Euclidean metric
123
+ wass_sp = wasserstein_distance (u .reshape ((- 1 ,)), v .reshape ((- 1 ,)), w_u , w_v )
124
+ np .testing .assert_allclose (wass_sp , wass1d_euc )
125
+
126
+ # check constraints
127
+ np .testing .assert_allclose (w_u , G .sum (1 ))
128
+ np .testing .assert_allclose (w_v , G .sum (0 ))
129
+
130
+
95
131
def test_wass_1d ():
96
132
# test emd1d gives similar results as emd
97
133
n = 20
@@ -135,7 +171,6 @@ def test_emd_empty():
135
171
136
172
137
173
def test_emd_sparse ():
138
-
139
174
n = 100
140
175
rng = np .random .RandomState (0 )
141
176
@@ -211,7 +246,6 @@ def test_emd2_multi():
211
246
212
247
213
248
def test_lp_barycenter ():
214
-
215
249
a1 = np .array ([1.0 , 0 , 0 ])[:, None ]
216
250
a2 = np .array ([0 , 0 , 1.0 ])[:, None ]
217
251
@@ -228,7 +262,6 @@ def test_lp_barycenter():
228
262
229
263
230
264
def test_free_support_barycenter ():
231
-
232
265
measures_locations = [np .array ([- 1. ]).reshape ((1 , 1 )), np .array ([1. ]).reshape ((1 , 1 ))]
233
266
measures_weights = [np .array ([1. ]), np .array ([1. ])]
234
267
@@ -244,7 +277,6 @@ def test_free_support_barycenter():
244
277
245
278
@pytest .mark .skipif (not ot .lp .cvx .cvxopt , reason = "No cvxopt available" )
246
279
def test_lp_barycenter_cvxopt ():
247
-
248
280
a1 = np .array ([1.0 , 0 , 0 ])[:, None ]
249
281
a2 = np .array ([0 , 0 , 1.0 ])[:, None ]
250
282
0 commit comments