Skip to content

Commit 0d333e0

Browse files
committed
Improved tests and docs for wasserstein_1d
1 parent 1140141 commit 0d333e0

File tree

4 files changed

+34
-10
lines changed

4 files changed

+34
-10
lines changed

ot/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from . import stochastic
2323

2424
# OT functions
25-
from .lp import emd, emd2, emd_1d, emd2_1d
25+
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, wasserstein2_1d
2626
from .bregman import sinkhorn, sinkhorn2, barycenter
2727
from .da import sinkhorn_lpl1_mm
2828

@@ -32,5 +32,6 @@
3232
__version__ = "0.5.1"
3333

3434
__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
35-
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'emd_1d', 'emd2_1d',
35+
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
36+
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'wasserstein2_1d',
3637
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/lp/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -530,13 +530,13 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
530530

531531

532532
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
533-
"""Solves the Wasserstein distance problem between 1d measures and returns
533+
"""Solves the p-Wasserstein distance problem between 1d measures and returns
534534
the OT matrix
535535
536536
537537
.. math::
538-
\gamma = arg\min_\gamma \left(\sum_i \sum_j \gamma_{ij}
539-
|x_a[i] - x_b[j]|^p \right)^{1/p}
538+
\gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
539+
|x_a[i] - x_b[j]|^p \\right)^{1/p}
540540
541541
s.t. \gamma 1 = a,
542542
\gamma^T 1= b,
@@ -617,15 +617,14 @@ def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
617617
dense=dense, log=log)
618618

619619

620-
def wasserstein2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.,
621-
dense=True, log=False):
622-
"""Solves the Wasserstein distance problem between 1d measures and returns
620+
def wasserstein2_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
621+
"""Solves the p-Wasserstein distance problem between 1d measures and returns
623622
the loss
624623
625624
626625
.. math::
627626
\gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
628-
|x_a[i] - x_b[j]|^p \right)^{1/p}
627+
|x_a[i] - x_b[j]|^p \\right)^{1/p}
629628
630629
s.t. \gamma 1 = a,
631630
\gamma^T 1= b,

ot/lp/emd_wrap.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cimport numpy as np
1313
from ..utils import dist
1414

1515
cimport cython
16+
cimport libc.math as math
1617

1718
import warnings
1819

@@ -159,7 +160,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
159160
elif metric == 'cityblock' or metric == 'euclidean':
160161
m_ij = abs(u[i] - v[j])
161162
elif metric == 'minkowski':
162-
m_ij = abs(u[i] - v[j]) ** p
163+
m_ij = math.pow(abs(u[i] - v[j]), p)
163164
else:
164165
m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
165166
metric=metric)[0, 0]

test/test_ot.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,29 @@ def test_emd_1d_emd2_1d():
8585
np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], [])
8686

8787

88+
def test_wass_1d():
89+
# test emd1d gives similar results as emd
90+
n = 20
91+
m = 30
92+
rng = np.random.RandomState(0)
93+
u = rng.randn(n, 1)
94+
v = rng.randn(m, 1)
95+
96+
M = ot.dist(u, v, metric='sqeuclidean')
97+
98+
G, log = ot.emd([], [], M, log=True)
99+
wass = log["cost"]
100+
101+
G_1d, log = ot.wasserstein_1d(u, v, [], [], p=2., log=True)
102+
wass1d = log["cost"]
103+
104+
# check loss is similar
105+
np.testing.assert_allclose(np.sqrt(wass), wass1d)
106+
107+
# check G is similar
108+
np.testing.assert_allclose(G, G_1d)
109+
110+
88111
def test_emd_empty():
89112
# test emd and emd2 for simple identity
90113
n = 100

0 commit comments

Comments
 (0)