Skip to content

Commit a9b8af1

Browse files
authored
Merge pull request #89 from rtavenar/master
[MRG] EMD and Wasserstein 1D
2 parents 2364d56 + 362a7f8 commit a9b8af1

File tree

5 files changed

+445
-9
lines changed

5 files changed

+445
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ The contributors to this library are:
167167
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
168168
* [Vayer Titouan](https://tvayer.github.io/)
169169
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
170+
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
170171

171172
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
172173

ot/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from . import unbalanced
2424

2525
# OT functions
26-
from .lp import emd, emd2
26+
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
2727
from .bregman import sinkhorn, sinkhorn2, barycenter
2828
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
2929
from .da import sinkhorn_lpl1_mm
@@ -33,7 +33,8 @@
3333

3434
__version__ = "0.5.1"
3535

36-
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
36+
__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
3737
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
38+
'emd_1d', 'emd2_1d', 'wasserstein_1d',
3839
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
3940
'sinkhorn_unbalanced', "barycenter_unbalanced"]

ot/lp/__init__.py

Lines changed: 292 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@
1010
import multiprocessing
1111

1212
import numpy as np
13+
from scipy.sparse import coo_matrix
1314

1415
from .import cvx
1516

1617
# import compiled emd
17-
from .emd_wrap import emd_c, check_result
18+
from .emd_wrap import emd_c, check_result, emd_1d_sorted
1819
from ..utils import parmap
1920
from .cvx import barycenter
2021
from ..utils import dist
2122

22-
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
23+
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
24+
'emd_1d', 'emd2_1d', 'wasserstein_1d']
2325

2426

2527
def emd(a, b, M, numItermax=100000, log=False):
@@ -94,7 +96,7 @@ def emd(a, b, M, numItermax=100000, log=False):
9496
b = np.asarray(b, dtype=np.float64)
9597
M = np.asarray(M, dtype=np.float64)
9698

97-
# if empty array given then use unifor distributions
99+
# if empty array given then use uniform distributions
98100
if len(a) == 0:
99101
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
100102
if len(b) == 0:
@@ -187,7 +189,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
187189
b = np.asarray(b, dtype=np.float64)
188190
M = np.asarray(M, dtype=np.float64)
189191

190-
# if empty array given then use unifor distributions
192+
# if empty array given then use uniform distributions
191193
if len(a) == 0:
192194
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
193195
if len(b) == 0:
@@ -308,4 +310,289 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
308310
log_dict['displacement_square_norms'] = displacement_square_norms
309311
return X, log_dict
310312
else:
311-
return X
313+
return X
314+
315+
316+
def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
317+
log=False):
318+
"""Solves the Earth Movers distance problem between 1d measures and returns
319+
the OT matrix
320+
321+
322+
.. math::
323+
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
324+
325+
s.t. \gamma 1 = a,
326+
\gamma^T 1= b,
327+
\gamma\geq 0
328+
where :
329+
330+
- d is the metric
331+
- x_a and x_b are the samples
332+
- a and b are the sample weights
333+
334+
When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
335+
336+
Uses the algorithm detailed in [1]_
337+
338+
Parameters
339+
----------
340+
x_a : (ns,) or (ns, 1) ndarray, float64
341+
Source dirac locations (on the real line)
342+
x_b : (nt,) or (ns, 1) ndarray, float64
343+
Target dirac locations (on the real line)
344+
a : (ns,) ndarray, float64, optional
345+
Source histogram (default is uniform weight)
346+
b : (nt,) ndarray, float64, optional
347+
Target histogram (default is uniform weight)
348+
metric: str, optional (default='sqeuclidean')
349+
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
350+
Due to implementation details, this function runs faster when
351+
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
352+
p: float, optional (default=1.0)
353+
The p-norm to apply for if metric='minkowski'
354+
dense: boolean, optional (default=True)
355+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
356+
Otherwise returns a sparse representation using scipy's `coo_matrix`
357+
format. Due to implementation details, this function runs faster when
358+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
359+
are used.
360+
log: boolean, optional (default=False)
361+
If True, returns a dictionary containing the cost.
362+
Otherwise returns only the optimal transportation matrix.
363+
364+
Returns
365+
-------
366+
gamma: (ns, nt) ndarray
367+
Optimal transportation matrix for the given parameters
368+
log: dict
369+
If input log is True, a dictionary containing the cost
370+
371+
372+
Examples
373+
--------
374+
375+
Simple example with obvious solution. The function emd_1d accepts lists and
376+
performs automatic conversion to numpy arrays
377+
378+
>>> import ot
379+
>>> a=[.5, .5]
380+
>>> b=[.5, .5]
381+
>>> x_a = [2., 0.]
382+
>>> x_b = [0., 3.]
383+
>>> ot.emd_1d(x_a, x_b, a, b)
384+
array([[0. , 0.5],
385+
[0.5, 0. ]])
386+
>>> ot.emd_1d(x_a, x_b)
387+
array([[0. , 0.5],
388+
[0.5, 0. ]])
389+
390+
References
391+
----------
392+
393+
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
394+
Transport", 2018.
395+
396+
See Also
397+
--------
398+
ot.lp.emd : EMD for multidimensional distributions
399+
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
400+
transportation matrix)
401+
"""
402+
a = np.asarray(a, dtype=np.float64)
403+
b = np.asarray(b, dtype=np.float64)
404+
x_a = np.asarray(x_a, dtype=np.float64)
405+
x_b = np.asarray(x_b, dtype=np.float64)
406+
407+
assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
408+
"emd_1d should only be used with monodimensional data"
409+
assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
410+
"emd_1d should only be used with monodimensional data"
411+
412+
# if empty array given then use uniform distributions
413+
if a.ndim == 0 or len(a) == 0:
414+
a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
415+
if b.ndim == 0 or len(b) == 0:
416+
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
417+
418+
x_a_1d = x_a.reshape((-1, ))
419+
x_b_1d = x_b.reshape((-1, ))
420+
perm_a = np.argsort(x_a_1d)
421+
perm_b = np.argsort(x_b_1d)
422+
423+
G_sorted, indices, cost = emd_1d_sorted(a, b,
424+
x_a_1d[perm_a], x_b_1d[perm_b],
425+
metric=metric, p=p)
426+
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
427+
shape=(a.shape[0], b.shape[0]))
428+
if dense:
429+
G = G.toarray()
430+
if log:
431+
log = {'cost': cost}
432+
return G, log
433+
return G
434+
435+
436+
def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
437+
log=False):
438+
"""Solves the Earth Movers distance problem between 1d measures and returns
439+
the loss
440+
441+
442+
.. math::
443+
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
444+
445+
s.t. \gamma 1 = a,
446+
\gamma^T 1= b,
447+
\gamma\geq 0
448+
where :
449+
450+
- d is the metric
451+
- x_a and x_b are the samples
452+
- a and b are the sample weights
453+
454+
When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
455+
456+
Uses the algorithm detailed in [1]_
457+
458+
Parameters
459+
----------
460+
x_a : (ns,) or (ns, 1) ndarray, float64
461+
Source dirac locations (on the real line)
462+
x_b : (nt,) or (ns, 1) ndarray, float64
463+
Target dirac locations (on the real line)
464+
a : (ns,) ndarray, float64, optional
465+
Source histogram (default is uniform weight)
466+
b : (nt,) ndarray, float64, optional
467+
Target histogram (default is uniform weight)
468+
metric: str, optional (default='sqeuclidean')
469+
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
470+
Due to implementation details, this function runs faster when
471+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
472+
are used.
473+
p: float, optional (default=1.0)
474+
The p-norm to apply for if metric='minkowski'
475+
dense: boolean, optional (default=True)
476+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
477+
Otherwise returns a sparse representation using scipy's `coo_matrix`
478+
format. Only used if log is set to True. Due to implementation details,
479+
this function runs faster when dense is set to False.
480+
log: boolean, optional (default=False)
481+
If True, returns a dictionary containing the transportation matrix.
482+
Otherwise returns only the loss.
483+
484+
Returns
485+
-------
486+
loss: float
487+
Cost associated to the optimal transportation
488+
log: dict
489+
If input log is True, a dictionary containing the Optimal transportation
490+
matrix for the given parameters
491+
492+
493+
Examples
494+
--------
495+
496+
Simple example with obvious solution. The function emd2_1d accepts lists and
497+
performs automatic conversion to numpy arrays
498+
499+
>>> import ot
500+
>>> a=[.5, .5]
501+
>>> b=[.5, .5]
502+
>>> x_a = [2., 0.]
503+
>>> x_b = [0., 3.]
504+
>>> ot.emd2_1d(x_a, x_b, a, b)
505+
0.5
506+
>>> ot.emd2_1d(x_a, x_b)
507+
0.5
508+
509+
References
510+
----------
511+
512+
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
513+
Transport", 2018.
514+
515+
See Also
516+
--------
517+
ot.lp.emd2 : EMD for multidimensional distributions
518+
ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
519+
instead of the cost)
520+
"""
521+
# If we do not return G (log==False), then we should not to cast it to dense
522+
# (useless overhead)
523+
G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
524+
dense=dense and log, log=True)
525+
cost = log_emd['cost']
526+
if log:
527+
log_emd = {'G': G}
528+
return cost, log_emd
529+
return cost
530+
531+
532+
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
533+
"""Solves the p-Wasserstein distance problem between 1d measures and returns
534+
the distance
535+
536+
537+
.. math::
538+
\gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
539+
|x_a[i] - x_b[j]|^p \\right)^{1/p}
540+
541+
s.t. \gamma 1 = a,
542+
\gamma^T 1= b,
543+
\gamma\geq 0
544+
where :
545+
546+
- x_a and x_b are the samples
547+
- a and b are the sample weights
548+
549+
Uses the algorithm detailed in [1]_
550+
551+
Parameters
552+
----------
553+
x_a : (ns,) or (ns, 1) ndarray, float64
554+
Source dirac locations (on the real line)
555+
x_b : (nt,) or (ns, 1) ndarray, float64
556+
Target dirac locations (on the real line)
557+
a : (ns,) ndarray, float64, optional
558+
Source histogram (default is uniform weight)
559+
b : (nt,) ndarray, float64, optional
560+
Target histogram (default is uniform weight)
561+
p: float, optional (default=1.0)
562+
The order of the p-Wasserstein distance to be computed
563+
564+
Returns
565+
-------
566+
dist: float
567+
p-Wasserstein distance
568+
569+
570+
Examples
571+
--------
572+
573+
Simple example with obvious solution. The function wasserstein_1d accepts
574+
lists and performs automatic conversion to numpy arrays
575+
576+
>>> import ot
577+
>>> a=[.5, .5]
578+
>>> b=[.5, .5]
579+
>>> x_a = [2., 0.]
580+
>>> x_b = [0., 3.]
581+
>>> ot.wasserstein_1d(x_a, x_b, a, b)
582+
0.5
583+
>>> ot.wasserstein_1d(x_a, x_b)
584+
0.5
585+
586+
References
587+
----------
588+
589+
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
590+
Transport", 2018.
591+
592+
See Also
593+
--------
594+
ot.lp.emd_1d : EMD for 1d distributions
595+
"""
596+
cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
597+
dense=False, log=False)
598+
return np.power(cost_emd, 1. / p)

0 commit comments

Comments
 (0)