21
21
from ..utils import dist
22
22
23
23
__all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
24
- 'emd_1d' , 'emd2_1d' ]
24
+ 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' , 'wasserstein2_1d' ]
25
25
26
26
27
27
def emd (a , b , M , numItermax = 100000 , log = False ):
@@ -313,7 +313,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313
313
return X
314
314
315
315
316
- def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
316
+ def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
317
+ log = False ):
317
318
"""Solves the Earth Movers distance problem between 1d measures and returns
318
319
the OT matrix
319
320
@@ -330,6 +331,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
330
331
- x_a and x_b are the samples
331
332
- a and b are the sample weights
332
333
334
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
335
+
333
336
Uses the algorithm detailed in [1]_
334
337
335
338
Parameters
@@ -346,11 +349,14 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
346
349
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347
350
Due to implementation details, this function runs faster when
348
351
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
352
+ p: float, optional (default=1.0)
353
+ The p-norm to apply if metric='minkowski'
349
354
dense: boolean, optional (default=True)
350
355
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
351
356
Otherwise returns a sparse representation using scipy's `coo_matrix`
352
357
format. Due to implementation details, this function runs faster when
353
- dense is set to False.
358
+ dense is set to False.
354
360
log: boolean, optional (default=False)
355
361
If True, returns a dictionary containing the cost.
356
362
Otherwise returns only the optimal transportation matrix.
@@ -416,7 +422,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
416
422
417
423
G_sorted , indices , cost = emd_1d_sorted (a , b ,
418
424
x_a_1d [perm_a ], x_b_1d [perm_b ],
419
- metric = metric )
425
+ metric = metric , p = p )
420
426
G = coo_matrix ((G_sorted , (perm_a [indices [:, 0 ]], perm_b [indices [:, 1 ]])),
421
427
shape = (a .shape [0 ], b .shape [0 ]))
422
428
if dense :
@@ -427,7 +433,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
427
433
return G
428
434
429
435
430
- def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
436
+ def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
437
+ log = False ):
431
438
"""Solves the Earth Movers distance problem between 1d measures and returns
432
439
the loss
433
440
@@ -444,6 +451,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
444
451
- x_a and x_b are the samples
445
452
- a and b are the sample weights
446
453
454
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
455
+
447
456
Uses the algorithm detailed in [1]_
448
457
449
458
Parameters
@@ -459,7 +468,10 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
459
468
metric: str, optional (default='sqeuclidean')
460
469
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
461
470
Due to implementation details, this function runs faster when
462
- `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
471
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
472
+ are used.
473
+ p: float, optional (default=1.0)
474
+ The p-norm to apply if metric='minkowski'
463
475
dense: boolean, optional (default=True)
464
476
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
465
477
Otherwise returns a sparse representation using scipy's `coo_matrix`
@@ -508,10 +520,185 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
508
520
"""
509
521
# If we do not return G (log==False), then we should not cast it to dense
510
522
# (useless overhead)
511
- G , log_emd = emd_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = metric ,
523
+ G , log_emd = emd_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = metric , p = p ,
512
524
dense = dense and log , log = True )
513
525
cost = log_emd ['cost' ]
514
526
if log :
515
527
log_emd = {'G' : G }
516
528
return cost , log_emd
517
- return cost
529
+ return cost
530
+
531
+
532
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
    r"""Solves the p-Wasserstein distance problem between 1d measures and
    returns the OT matrix


    .. math::
        \gamma = arg\min_\gamma \left(\sum_i \sum_j \gamma_{ij}
        |x_a[i] - x_b[j]|^p \right)^{1/p}

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0
    where :

    - x_a and x_b are the samples
    - a and b are the sample weights

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    p: float, optional (default=1.0)
        The order of the p-Wasserstein distance to be computed
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Due to implementation details, this function runs faster when
        dense is set to False.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the cost.
        Otherwise returns only the optimal transportation matrix.

    Returns
    -------
    gamma: (ns, nt) ndarray
        Optimal transportation matrix for the given parameters
    log: dict
        If input log is True, a dictionary containing the cost (raised to
        the power 1/p)


    Examples
    --------

    Simple example with obvious solution. The function wasserstein_1d accepts
    lists and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.wasserstein_1d(x_a, x_b, a, b)
    array([[0. , 0.5],
           [0.5, 0. ]])
    >>> ot.wasserstein_1d(x_a, x_b)
    array([[0. , 0.5],
           [0.5, 0. ]])

    References
    ----------

    .. [1]  Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.

    See Also
    --------
    ot.lp.emd_1d : EMD for 1d distributions
    ot.lp.wasserstein2_1d : Wasserstein for 1d distributions (returns the cost
        instead of the transportation matrix)
    """
    # The 'minkowski' metric makes emd_1d use the ground cost |x - y|^p.
    result = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
                    dense=dense, log=log)
    if not log:
        return result
    # With log enabled emd_1d returns (G, dict); the stored cost is the
    # un-rooted sum, so take the p-th root to report the Wasserstein distance.
    G, log_emd = result
    log_emd['cost'] = np.power(log_emd['cost'], 1. / p)
    return G, log_emd
618
+
619
+
620
def wasserstein2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.,
                    dense=True, log=False):
    r"""Solves the p-Wasserstein distance problem between 1d measures and
    returns the loss


    .. math::
        \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
        |x_a[i] - x_b[j]|^p \right)^{1/p}

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0
    where :

    - x_a and x_b are the samples
    - a and b are the sample weights

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Ignored. The ground cost is always :math:`|x - y|^p` (the
        `'minkowski'` metric is passed to `emd2_1d` regardless of this
        value). The argument is kept in the signature only so that existing
        callers passing it keep working.
    p: float, optional (default=1.0)
        The order of the p-Wasserstein distance to be computed
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Only used if log is set to True. Due to implementation details,
        this function runs faster when dense is set to False.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the transportation matrix.
        Otherwise returns only the loss.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation
    log: dict
        If input log is True, a dictionary containing the Optimal transportation
        matrix for the given parameters


    Examples
    --------

    Simple example with obvious solution. The function wasserstein2_1d accepts
    lists and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.wasserstein2_1d(x_a, x_b, a, b)
    0.5
    >>> ot.wasserstein2_1d(x_a, x_b)
    0.5

    References
    ----------

    .. [1]  Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.

    See Also
    --------
    ot.lp.emd2_1d : EMD for 1d distributions
    ot.lp.wasserstein_1d : Wasserstein for 1d distributions (returns the
        transportation matrix instead of the cost)
    """
    # The 'minkowski' metric makes emd2_1d use the ground cost |x - y|^p;
    # the `metric` argument of this function is deliberately not forwarded.
    result = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
                     dense=dense, log=log)
    if log:
        # emd2_1d returns (cost, dict) when logging is requested.
        cost, log_emd = result
        return np.power(cost, 1. / p), log_emd
    # Take the p-th root of the transport cost to obtain the distance.
    return np.power(result, 1. / p)
0 commit comments