Commit d399f62

Merge pull request #139 from AdrienCorenflos/master
[MRG] Fix ordering
2 parents fa06bb3 + a9e6950 commit d399f62

File tree

2 files changed: +51 -18 lines changed

ot/lp/__init__.py

Lines changed: 11 additions & 10 deletions
@@ -12,16 +12,16 @@
 
 import multiprocessing
 import sys
+
 import numpy as np
 from scipy.sparse import coo_matrix
 
-from .import cvx
-
+from . import cvx
+from .cvx import barycenter
 # import compiled emd
 from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import parmap
-from .cvx import barycenter
 from ..utils import dist
+from ..utils import parmap
 
 __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
            'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -458,7 +458,8 @@ def f(b):
         return res
 
 
-def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
+def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
+                            stopThr=1e-7, verbose=False, log=None):
     """
     Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
 
     T_sum = np.zeros((k, d))
 
-    for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
-
+    for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
+                                                                  weights.tolist()):
         M_i = dist(X, measure_locations_i)
         T_i = emd(b, measure_weights_i, M_i)
         T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
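For context on the loop re-wrapped in this hunk (the code change here is purely line length): with w_i the barycenter coefficient `weight_i`, b the weights of the current barycenter support X, T_i the optimal plan returned by `emd` between (X, b) and measure i, and X_i the array `measure_locations_i`, the accumulated quantity is, in LaTeX notation,

    T_{\mathrm{sum}} = \sum_i w_i \, \operatorname{diag}(1/b) \, T_i X_i ,

i.e. the weighted average of the barycentric projections of the input measures. The assignment of this sum back to X happens further down in the function, outside this hunk; this reading is inferred only from the lines shown here.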
@@ -651,12 +652,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     if b.ndim == 0 or len(b) == 0:
         b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
 
-    x_a_1d = x_a.reshape((-1, ))
-    x_b_1d = x_b.reshape((-1, ))
+    x_a_1d = x_a.reshape((-1,))
+    x_b_1d = x_b.reshape((-1,))
     perm_a = np.argsort(x_a_1d)
     perm_b = np.argsort(x_b_1d)
 
-    G_sorted, indices, cost = emd_1d_sorted(a, b,
+    G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
                                             x_a_1d[perm_a], x_b_1d[perm_b],
                                             metric=metric, p=p)
     G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
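The one-line change to the `emd_1d_sorted` call above is the actual ordering fix this PR is named after: `emd_1d` sorts the support points with `argsort`, so the weight vectors `a` and `b` must be permuted with the same indices as the positions. With uniform weights `a[perm_a]` equals `a`, which is why the old call only broke for non-uniform weights. A minimal sketch of the invariant (illustrative only, not the library's internal code), using plain NumPy:

import numpy as np

rng = np.random.RandomState(0)
x_a = rng.randn(5)              # unsorted 1-D support points
a = rng.uniform(0., 1., 5)
a = a / a.sum()                 # non-uniform weights attached to x_a

perm_a = np.argsort(x_a)
x_sorted = x_a[perm_a]

# Any weighted statistic is preserved only if the weights are permuted
# together with the points; re-using the unpermuted weights silently
# attaches each weight to the wrong (sorted) point.
mean_correct = np.dot(a[perm_a], x_sorted)   # == np.dot(a, x_a)
mean_wrong = np.dot(a, x_sorted)             # generally different

assert np.isclose(mean_correct, np.dot(a, x_a))
print(mean_correct, mean_wrong)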

test/test_ot.py

Lines changed: 40 additions & 8 deletions
@@ -7,11 +7,11 @@
 import warnings
 
 import numpy as np
+import pytest
 from scipy.stats import wasserstein_distance
 
 import ot
 from ot.datasets import make_1D_gauss as gauss
-import pytest
 
 
 def test_emd_dimension_mismatch():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
     np.testing.assert_allclose(wass, wass1d_emd2)
 
     # check loss is similar to scipy's implementation for Euclidean metric
-    wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
     np.testing.assert_allclose(wass_sp, wass1d_euc)
 
     # check constraints
-    np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
-    np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+    np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+    np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
 
     # check G is similar
     np.testing.assert_allclose(G, G_1d)
@@ -92,6 +92,42 @@ def test_emd_1d_emd2_1d():
     ot.emd_1d(u, v, [], [])
 
 
+def test_emd_1d_emd2_1d_with_weights():
+    # test emd1d gives similar results as emd
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.randn(n, 1)
+    v = rng.randn(m, 1)
+
+    w_u = rng.uniform(0., 1., n)
+    w_u = w_u / w_u.sum()
+
+    w_v = rng.uniform(0., 1., m)
+    w_v = w_v / w_v.sum()
+
+    M = ot.dist(u, v, metric='sqeuclidean')
+
+    G, log = ot.emd(w_u, w_v, M, log=True)
+    wass = log["cost"]
+    G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+    wass1d = log["cost"]
+    wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+    wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass, wass1d)
+    np.testing.assert_allclose(wass, wass1d_emd2)
+
+    # check loss is similar to scipy's implementation for Euclidean metric
+    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+    np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+    # check constraints
+    np.testing.assert_allclose(w_u, G.sum(1))
+    np.testing.assert_allclose(w_v, G.sum(0))
+
+
 def test_wass_1d():
     # test emd1d gives similar results as emd
     n = 20
@@ -135,7 +171,6 @@ def test_emd_empty():
 
 
 def test_emd_sparse():
-
     n = 100
     rng = np.random.RandomState(0)
 
@@ -211,7 +246,6 @@ def test_emd2_multi():
 
 
 def test_lp_barycenter():
-
     a1 = np.array([1.0, 0, 0])[:, None]
     a2 = np.array([0, 0, 1.0])[:, None]
 
@@ -228,7 +262,6 @@ def test_lp_barycenter():
 
 
 def test_free_support_barycenter():
-
     measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
     measures_weights = [np.array([1.]), np.array([1.])]
 
@@ -244,7 +277,6 @@ def test_free_support_barycenter():
 
 @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
 def test_lp_barycenter_cvxopt():
-
     a1 = np.array([1.0, 0, 0])[:, None]
     a2 = np.array([0, 0, 1.0])[:, None]
 