@@ -313,7 +313,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313
313
return X
314
314
315
315
316
- def emd_1d (x_a , x_b , a , b , metric = 'sqeuclidean' , dense = True , log = False ):
316
+ def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
317
317
"""Solves the Earth Movers distance problem between 1d measures and returns
318
318
the OT matrix
319
319
@@ -338,10 +338,10 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
338
338
Source dirac locations (on the real line)
339
339
x_b : (nt,) or (ns, 1) ndarray, float64
340
340
Target dirac locations (on the real line)
341
- a : (ns,) ndarray, float64
342
- Source histogram (uniform weight if empty list )
343
- b : (nt,) ndarray, float64
344
- Target histogram (uniform weight if empty list )
341
+ a : (ns,) ndarray, float64, optional
342
+ Source histogram (default is uniform weight )
343
+ b : (nt,) ndarray, float64, optional
344
+ Target histogram (default is uniform weight )
345
345
metric: str, optional (default='sqeuclidean')
346
346
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347
347
Due to implementation details, this function runs faster when
@@ -375,6 +375,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
375
375
>>> x_a = [2., 0.]
376
376
>>> x_b = [0., 3.]
377
377
>>> ot.emd_1d(x_a, x_b, a, b)
378
+ array([[0. , 0.5],
379
+ [0.5, 0. ]])
380
+ >>> ot.emd_1d(x_a, x_b)
378
381
array([[0. , 0.5],
379
382
[0.5, 0. ]])
380
383
@@ -401,9 +404,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
401
404
"emd_1d should only be used with monodimensional data"
402
405
403
406
# if empty array given then use uniform distributions
404
- if len (a ) == 0 :
407
+ if a . ndim == 0 or len (a ) == 0 :
405
408
a = np .ones ((x_a .shape [0 ],), dtype = np .float64 ) / x_a .shape [0 ]
406
- if len (b ) == 0 :
409
+ if b . ndim == 0 or len (b ) == 0 :
407
410
b = np .ones ((x_b .shape [0 ],), dtype = np .float64 ) / x_b .shape [0 ]
408
411
409
412
x_a_1d = x_a .reshape ((- 1 , ))
@@ -424,7 +427,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
424
427
return G
425
428
426
429
427
- def emd2_1d (x_a , x_b , a , b , metric = 'sqeuclidean' , dense = True , log = False ):
430
+ def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
428
431
"""Solves the Earth Movers distance problem between 1d measures and returns
429
432
the loss
430
433
@@ -449,10 +452,10 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
449
452
Source dirac locations (on the real line)
450
453
x_b : (nt,) or (ns, 1) ndarray, float64
451
454
Target dirac locations (on the real line)
452
- a : (ns,) ndarray, float64
453
- Source histogram (uniform weight if empty list )
454
- b : (nt,) ndarray, float64
455
- Target histogram (uniform weight if empty list )
455
+ a : (ns,) ndarray, float64, optional
456
+ Source histogram (default is uniform weight )
457
+ b : (nt,) ndarray, float64, optional
458
+ Target histogram (default is uniform weight )
456
459
metric: str, optional (default='sqeuclidean')
457
460
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
458
461
Due to implementation details, this function runs faster when
@@ -488,6 +491,8 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
488
491
>>> x_b = [0., 3.]
489
492
>>> ot.emd2_1d(x_a, x_b, a, b)
490
493
0.5
494
+ >>> ot.emd2_1d(x_a, x_b)
495
+ 0.5
491
496
492
497
References
493
498
----------
0 commit comments