Skip to content

Commit c92e595

Browse files
committed
Wasserstein defined as the cost itself (do not return transportation matrix)
1 parent bbc56e7 commit c92e595

File tree

3 files changed

+13
-122
lines changed

3 files changed

+13
-122
lines changed

ot/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from . import unbalanced
2424

2525
# OT functions
26-
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, wasserstein2_1d
26+
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
2727
from .bregman import sinkhorn, sinkhorn2, barycenter
2828
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
2929
from .da import sinkhorn_lpl1_mm
@@ -35,6 +35,6 @@
3535

3636
__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
3737
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
38-
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'wasserstein2_1d',
38+
'emd_1d', 'emd2_1d', 'wasserstein_1d',
3939
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
4040
'sinkhorn_unbalanced', "barycenter_unbalanced"]

ot/lp/__init__.py

Lines changed: 10 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..utils import dist
2222

2323
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
24-
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'wasserstein2_1d']
24+
'emd_1d', 'emd2_1d', 'wasserstein_1d']
2525

2626

2727
def emd(a, b, M, numItermax=100000, log=False):
@@ -529,9 +529,9 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
529529
return cost
530530

531531

532-
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
532+
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
533533
"""Solves the p-Wasserstein distance problem between 1d measures and returns
534-
the OT matrix
534+
the distance
535535
536536
537537
.. math::
@@ -560,22 +560,11 @@ def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
560560
Target histogram (default is uniform weight)
561561
p: float, optional (default=1.0)
562562
The order of the p-Wasserstein distance to be computed
563-
dense: boolean, optional (default=True)
564-
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
565-
Otherwise returns a sparse representation using scipy's `coo_matrix`
566-
format. Due to implementation details, this function runs faster when
567-
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
568-
are used.
569-
log: boolean, optional (default=False)
570-
If True, returns a dictionary containing the cost.
571-
Otherwise returns only the optimal transportation matrix.
572563
573564
Returns
574565
-------
575-
gamma: (ns, nt) ndarray
576-
Optimal transportation matrix for the given parameters
577-
log: dict
578-
If input log is True, a dictionary containing the cost
566+
dist: float
567+
p-Wasserstein distance
579568
580569
581570
Examples
@@ -590,96 +579,8 @@ def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
590579
>>> x_a = [2., 0.]
591580
>>> x_b = [0., 3.]
592581
>>> ot.wasserstein_1d(x_a, x_b, a, b)
593-
array([[0. , 0.5],
594-
[0.5, 0. ]])
595-
>>> ot.wasserstein_1d(x_a, x_b)
596-
array([[0. , 0.5],
597-
[0.5, 0. ]])
598-
599-
References
600-
----------
601-
602-
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
603-
Transport", 2018.
604-
605-
See Also
606-
--------
607-
ot.lp.emd_1d : EMD for 1d distributions
608-
ot.lp.wasserstein2_1d : Wasserstein for 1d distributions (returns the cost
609-
instead of the transportation matrix)
610-
"""
611-
if log:
612-
G, log = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
613-
dense=dense, log=log)
614-
log['cost'] = np.power(log['cost'], 1. / p)
615-
return G, log
616-
return emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
617-
dense=dense, log=log)
618-
619-
620-
def wasserstein2_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
621-
"""Solves the p-Wasserstein distance problem between 1d measures and returns
622-
the loss
623-
624-
625-
.. math::
626-
\gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
627-
|x_a[i] - x_b[j]|^p \\right)^{1/p}
628-
629-
s.t. \gamma 1 = a,
630-
\gamma^T 1= b,
631-
\gamma\geq 0
632-
where :
633-
634-
- x_a and x_b are the samples
635-
- a and b are the sample weights
636-
637-
Uses the algorithm detailed in [1]_
638-
639-
Parameters
640-
----------
641-
x_a : (ns,) or (ns, 1) ndarray, float64
642-
Source dirac locations (on the real line)
643-
x_b : (nt,) or (ns, 1) ndarray, float64
644-
Target dirac locations (on the real line)
645-
a : (ns,) ndarray, float64, optional
646-
Source histogram (default is uniform weight)
647-
b : (nt,) ndarray, float64, optional
648-
Target histogram (default is uniform weight)
649-
p: float, optional (default=1.0)
650-
The order of the p-Wasserstein distance to be computed
651-
dense: boolean, optional (default=True)
652-
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
653-
Otherwise returns a sparse representation using scipy's `coo_matrix`
654-
format. Only used if log is set to True. Due to implementation details,
655-
this function runs faster when dense is set to False.
656-
log: boolean, optional (default=False)
657-
If True, returns a dictionary containing the transportation matrix.
658-
Otherwise returns only the loss.
659-
660-
Returns
661-
-------
662-
loss: float
663-
Cost associated to the optimal transportation
664-
log: dict
665-
If input log is True, a dictionary containing the Optimal transportation
666-
matrix for the given parameters
667-
668-
669-
Examples
670-
--------
671-
672-
Simple example with obvious solution. The function wasserstein2_1d accepts
673-
lists and performs automatic conversion to numpy arrays
674-
675-
>>> import ot
676-
>>> a=[.5, .5]
677-
>>> b=[.5, .5]
678-
>>> x_a = [2., 0.]
679-
>>> x_b = [0., 3.]
680-
>>> ot.wasserstein2_1d(x_a, x_b, a, b)
681582
0.5
682-
>>> ot.wasserstein2_1d(x_a, x_b)
583+
>>> ot.wasserstein_1d(x_a, x_b)
683584
0.5
684585
685586
References
@@ -690,14 +591,8 @@ def wasserstein2_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
690591
691592
See Also
692593
--------
693-
ot.lp.emd2_1d : EMD for 1d distributions
694-
ot.lp.wasserstein_1d : Wasserstein for 1d distributions (returns the
695-
transportation matrix instead of the cost)
594+
ot.lp.emd_1d : EMD for 1d distributions
696595
"""
697-
if log:
698-
cost, log = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
699-
dense=dense, log=log)
700-
cost = np.power(cost, 1. / p)
701-
return cost, log
702-
return np.power(emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
703-
dense=dense, log=log), 1. / p)
596+
cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
597+
dense=False, log=False)
598+
return np.power(cost_emd, 1. / p)

test/test_ot.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,11 @@ def test_wass_1d():
9898
G, log = ot.emd([], [], M, log=True)
9999
wass = log["cost"]
100100

101-
G_1d, log = ot.wasserstein_1d(u, v, [], [], p=2., log=True)
102-
wass1d = log["cost"]
101+
wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
103102

104103
# check loss is similar
105104
np.testing.assert_allclose(np.sqrt(wass), wass1d)
106105

107-
# check G is similar
108-
np.testing.assert_allclose(G, G_1d)
109-
110106

111107
def test_emd_empty():
112108
# test emd and emd2 for simple identity

0 commit comments

Comments
 (0)