Skip to content

Commit 1140141

Browse files
committed
Added minkowski variants and wasserstein_1d functions
1 parent 0a039eb commit 1140141

File tree

2 files changed

+203
-10
lines changed

2 files changed

+203
-10
lines changed

ot/lp/__init__.py

Lines changed: 195 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..utils import dist
2222

2323
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
24-
'emd_1d', 'emd2_1d']
24+
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'wasserstein2_1d']
2525

2626

2727
def emd(a, b, M, numItermax=100000, log=False):
@@ -313,7 +313,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313313
return X
314314

315315

316-
def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False):
316+
def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
317+
log=False):
317318
"""Solves the Earth Movers distance problem between 1d measures and returns
318319
the OT matrix
319320
@@ -330,6 +331,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
330331
- x_a and x_b are the samples
331332
- a and b are the sample weights
332333
334+
When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
335+
333336
Uses the algorithm detailed in [1]_
334337
335338
Parameters
@@ -346,11 +349,14 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
346349
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347350
Due to implementation details, this function runs faster when
348351
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
352+
p: float, optional (default=1.0)
353+
The p-norm to apply for if metric='minkowski'
349354
dense: boolean, optional (default=True)
350355
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
351356
Otherwise returns a sparse representation using scipy's `coo_matrix`
352357
format. Due to implementation details, this function runs faster when
353-
dense is set to False.
358+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
359+
are used.
354360
log: boolean, optional (default=False)
355361
If True, returns a dictionary containing the cost.
356362
Otherwise returns only the optimal transportation matrix.
@@ -416,7 +422,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
416422

417423
G_sorted, indices, cost = emd_1d_sorted(a, b,
418424
x_a_1d[perm_a], x_b_1d[perm_b],
419-
metric=metric)
425+
metric=metric, p=p)
420426
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
421427
shape=(a.shape[0], b.shape[0]))
422428
if dense:
@@ -427,7 +433,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
427433
return G
428434

429435

430-
def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False):
436+
def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
437+
log=False):
431438
"""Solves the Earth Movers distance problem between 1d measures and returns
432439
the loss
433440
@@ -444,6 +451,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
444451
- x_a and x_b are the samples
445452
- a and b are the sample weights
446453
454+
When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
455+
447456
Uses the algorithm detailed in [1]_
448457
449458
Parameters
@@ -459,7 +468,10 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
459468
metric: str, optional (default='sqeuclidean')
460469
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
461470
Due to implementation details, this function runs faster when
462-
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
471+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
472+
are used.
473+
p: float, optional (default=1.0)
474+
The p-norm to apply for if metric='minkowski'
463475
dense: boolean, optional (default=True)
464476
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
465477
Otherwise returns a sparse representation using scipy's `coo_matrix`
@@ -508,10 +520,185 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
508520
"""
509521
# If we do not return G (log==False), then we should not to cast it to dense
510522
# (useless overhead)
511-
G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric,
523+
G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
512524
dense=dense and log, log=True)
513525
cost = log_emd['cost']
514526
if log:
515527
log_emd = {'G': G}
516528
return cost, log_emd
517-
return cost
529+
return cost
530+
531+
532+
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
533+
"""Solves the Wasserstein distance problem between 1d measures and returns
534+
the OT matrix
535+
536+
537+
.. math::
538+
\gamma = arg\min_\gamma \left(\sum_i \sum_j \gamma_{ij}
539+
|x_a[i] - x_b[j]|^p \right)^{1/p}
540+
541+
s.t. \gamma 1 = a,
542+
\gamma^T 1= b,
543+
\gamma\geq 0
544+
where :
545+
546+
- x_a and x_b are the samples
547+
- a and b are the sample weights
548+
549+
Uses the algorithm detailed in [1]_
550+
551+
Parameters
552+
----------
553+
x_a : (ns,) or (ns, 1) ndarray, float64
554+
Source dirac locations (on the real line)
555+
x_b : (nt,) or (ns, 1) ndarray, float64
556+
Target dirac locations (on the real line)
557+
a : (ns,) ndarray, float64, optional
558+
Source histogram (default is uniform weight)
559+
b : (nt,) ndarray, float64, optional
560+
Target histogram (default is uniform weight)
561+
p: float, optional (default=1.0)
562+
The order of the p-Wasserstein distance to be computed
563+
dense: boolean, optional (default=True)
564+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
565+
Otherwise returns a sparse representation using scipy's `coo_matrix`
566+
format. Due to implementation details, this function runs faster when
567+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
568+
are used.
569+
log: boolean, optional (default=False)
570+
If True, returns a dictionary containing the cost.
571+
Otherwise returns only the optimal transportation matrix.
572+
573+
Returns
574+
-------
575+
gamma: (ns, nt) ndarray
576+
Optimal transportation matrix for the given parameters
577+
log: dict
578+
If input log is True, a dictionary containing the cost
579+
580+
581+
Examples
582+
--------
583+
584+
Simple example with obvious solution. The function wasserstein_1d accepts
585+
lists and performs automatic conversion to numpy arrays
586+
587+
>>> import ot
588+
>>> a=[.5, .5]
589+
>>> b=[.5, .5]
590+
>>> x_a = [2., 0.]
591+
>>> x_b = [0., 3.]
592+
>>> ot.wasserstein_1d(x_a, x_b, a, b)
593+
array([[0. , 0.5],
594+
[0.5, 0. ]])
595+
>>> ot.wasserstein_1d(x_a, x_b)
596+
array([[0. , 0.5],
597+
[0.5, 0. ]])
598+
599+
References
600+
----------
601+
602+
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
603+
Transport", 2018.
604+
605+
See Also
606+
--------
607+
ot.lp.emd_1d : EMD for 1d distributions
608+
ot.lp.wasserstein2_1d : Wasserstein for 1d distributions (returns the cost
609+
instead of the transportation matrix)
610+
"""
611+
if log:
612+
G, log = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
613+
dense=dense, log=log)
614+
log['cost'] = np.power(log['cost'], 1. / p)
615+
return G, log
616+
return emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
617+
dense=dense, log=log)
618+
619+
620+
def wasserstein2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.,
621+
dense=True, log=False):
622+
"""Solves the Wasserstein distance problem between 1d measures and returns
623+
the loss
624+
625+
626+
.. math::
627+
\gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
628+
|x_a[i] - x_b[j]|^p \right)^{1/p}
629+
630+
s.t. \gamma 1 = a,
631+
\gamma^T 1= b,
632+
\gamma\geq 0
633+
where :
634+
635+
- x_a and x_b are the samples
636+
- a and b are the sample weights
637+
638+
Uses the algorithm detailed in [1]_
639+
640+
Parameters
641+
----------
642+
x_a : (ns,) or (ns, 1) ndarray, float64
643+
Source dirac locations (on the real line)
644+
x_b : (nt,) or (ns, 1) ndarray, float64
645+
Target dirac locations (on the real line)
646+
a : (ns,) ndarray, float64, optional
647+
Source histogram (default is uniform weight)
648+
b : (nt,) ndarray, float64, optional
649+
Target histogram (default is uniform weight)
650+
p: float, optional (default=1.0)
651+
The order of the p-Wasserstein distance to be computed
652+
dense: boolean, optional (default=True)
653+
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
654+
Otherwise returns a sparse representation using scipy's `coo_matrix`
655+
format. Only used if log is set to True. Due to implementation details,
656+
this function runs faster when dense is set to False.
657+
log: boolean, optional (default=False)
658+
If True, returns a dictionary containing the transportation matrix.
659+
Otherwise returns only the loss.
660+
661+
Returns
662+
-------
663+
loss: float
664+
Cost associated to the optimal transportation
665+
log: dict
666+
If input log is True, a dictionary containing the Optimal transportation
667+
matrix for the given parameters
668+
669+
670+
Examples
671+
--------
672+
673+
Simple example with obvious solution. The function wasserstein2_1d accepts
674+
lists and performs automatic conversion to numpy arrays
675+
676+
>>> import ot
677+
>>> a=[.5, .5]
678+
>>> b=[.5, .5]
679+
>>> x_a = [2., 0.]
680+
>>> x_b = [0., 3.]
681+
>>> ot.wasserstein2_1d(x_a, x_b, a, b)
682+
0.5
683+
>>> ot.wasserstein2_1d(x_a, x_b)
684+
0.5
685+
686+
References
687+
----------
688+
689+
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
690+
Transport", 2018.
691+
692+
See Also
693+
--------
694+
ot.lp.emd2_1d : EMD for 1d distributions
695+
ot.lp.wasserstein_1d : Wasserstein for 1d distributions (returns the
696+
transportation matrix instead of the cost)
697+
"""
698+
if log:
699+
cost, log = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
700+
dense=dense, log=log)
701+
cost = np.power(cost, 1. / p)
702+
return cost, log
703+
return np.power(emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
704+
dense=dense, log=log), 1. / p)

ot/lp/emd_wrap.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
103103
np.ndarray[double, ndim=1, mode="c"] v_weights,
104104
np.ndarray[double, ndim=1, mode="c"] u,
105105
np.ndarray[double, ndim=1, mode="c"] v,
106-
str metric='sqeuclidean'):
106+
str metric='sqeuclidean',
107+
double p=1.):
107108
r"""
108109
Solves the Earth Movers distance problem between sorted 1d measures and
109110
returns the OT matrix and the associated cost
@@ -121,7 +122,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
121122
metric: str, optional (default='sqeuclidean')
122123
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
123124
Due to implementation details, this function runs faster when
124-
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
125+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
126+
are used.
127+
p: float, optional (default=1.0)
128+
The p-norm to apply for if metric='minkowski'
125129
126130
Returns
127131
-------
@@ -154,6 +158,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
154158
m_ij = (u[i] - v[j]) ** 2
155159
elif metric == 'cityblock' or metric == 'euclidean':
156160
m_ij = abs(u[i] - v[j])
161+
elif metric == 'minkowski':
162+
m_ij = abs(u[i] - v[j]) ** p
157163
else:
158164
m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
159165
metric=metric)[0, 0]

0 commit comments

Comments
 (0)