Skip to content

Commit 0a039eb

Browse files
committed
Made weight vectors optional to match scipy's wass1d API
1 parent 77452dd commit 0a039eb

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

ot/lp/__init__.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313313
return X
314314

315315

316-
def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
316+
def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False):
317317
"""Solves the Earth Movers distance problem between 1d measures and returns
318318
the OT matrix
319319
@@ -338,10 +338,10 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
338338
Source dirac locations (on the real line)
339339
x_b : (nt,) or (ns, 1) ndarray, float64
340340
Target dirac locations (on the real line)
341-
a : (ns,) ndarray, float64
342-
Source histogram (uniform weight if empty list)
343-
b : (nt,) ndarray, float64
344-
Target histogram (uniform weight if empty list)
341+
a : (ns,) ndarray, float64, optional
342+
Source histogram (default is uniform weight)
343+
b : (nt,) ndarray, float64, optional
344+
Target histogram (default is uniform weight)
345345
metric: str, optional (default='sqeuclidean')
346346
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347347
Due to implementation details, this function runs faster when
@@ -375,6 +375,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
375375
>>> x_a = [2., 0.]
376376
>>> x_b = [0., 3.]
377377
>>> ot.emd_1d(x_a, x_b, a, b)
378+
array([[0. , 0.5],
379+
[0.5, 0. ]])
380+
>>> ot.emd_1d(x_a, x_b)
378381
array([[0. , 0.5],
379382
[0.5, 0. ]])
380383
@@ -401,9 +404,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
401404
"emd_1d should only be used with monodimensional data"
402405

403406
# if empty array given then use uniform distributions
404-
if len(a) == 0:
407+
if a.ndim == 0 or len(a) == 0:
405408
a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
406-
if len(b) == 0:
409+
if b.ndim == 0 or len(b) == 0:
407410
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
408411

409412
x_a_1d = x_a.reshape((-1, ))
@@ -424,7 +427,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
424427
return G
425428

426429

427-
def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
430+
def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False):
428431
"""Solves the Earth Movers distance problem between 1d measures and returns
429432
the loss
430433
@@ -449,10 +452,10 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
449452
Source dirac locations (on the real line)
450453
x_b : (nt,) or (ns, 1) ndarray, float64
451454
Target dirac locations (on the real line)
452-
a : (ns,) ndarray, float64
453-
Source histogram (uniform weight if empty list)
454-
b : (nt,) ndarray, float64
455-
Target histogram (uniform weight if empty list)
455+
a : (ns,) ndarray, float64, optional
456+
Source histogram (default is uniform weight)
457+
b : (nt,) ndarray, float64, optional
458+
Target histogram (default is uniform weight)
456459
metric: str, optional (default='sqeuclidean')
457460
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
458461
Due to implementation details, this function runs faster when
@@ -488,6 +491,8 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
488491
>>> x_b = [0., 3.]
489492
>>> ot.emd2_1d(x_a, x_b, a, b)
490493
0.5
494+
>>> ot.emd2_1d(x_a, x_b)
495+
0.5
491496
492497
References
493498
----------

0 commit comments

Comments
 (0)