Skip to content

Commit 7c2a952

Browse files
clecozCamille Le Cozrflamary
authored
[MRG] raise error if mass mismatch in emd2 (#386)
* Two lines added in the function emd2 to ensure that the distributions have the same mass (same as it already was in the function emd). * The same mass test has been moved inside the function f(b) to be compatible with emd2 with multiple b. * Test added. The function test_emd_dimension_and_mass_mismatch (in test/test_ot.py) has been modified to check for mass mismatch with emd2. * Add PR in releases.md * Merge and add PR in releases.md * Add name in contributors.md * Correction contribution in contributors.md * Move test on mass outside of functions f(b) * Update doc of emd and emd2 Co-authored-by: Camille Le Coz <[email protected]> Co-authored-by: Rémi Flamary <[email protected]>
1 parent e547fe3 commit 7c2a952

File tree

4 files changed

+14
-0
lines changed

4 files changed

+14
-0
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ The contributors to this library are:
3838
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
3939
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
4040
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
41+
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
4142

4243
## Acknowledgments
4344

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
- Fixed an issue where pointers would overflow in the EMD solver, returning an
1818
incomplete transport plan above a certain size (slightly above 46k, its square being
1919
roughly 2^31) (PR #381)
20+
- Error raised when mass mismatch in emd2 (PR #386)
2021

2122

2223
## 0.8.2

ot/lp/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
230230
If this behaviour is unwanted, please make sure to provide a
231231
floating point input.
232232
233+
.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
234+
233235
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
234236
235237
Parameters
@@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1,
389391
If this behaviour is unwanted, please make sure to provide a
390392
floating point input.
391393
394+
.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
395+
392396
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
393397
394398
Parameters
@@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1,
481485
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
482486
"Dimension mismatch, check dimensions of M with a and b"
483487

488+
# ensure that same mass
489+
np.testing.assert_almost_equal(a.sum(0),
490+
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
491+
b = b * a.sum(0) / b.sum(0,keepdims=True)
492+
484493
asel = a != 0
485494

486495
numThreads = check_number_threads(numThreads)

test/test_ot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():
2929

3030
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
3131

32+
# test emd and emd2 for mass mismatch
33+
a = ot.utils.unif(n_samples)
3234
b = a.copy()
3335
a[0] = 100
3436
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
37+
np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)
3538

3639

3740
def test_emd_backends(nx):

0 commit comments

Comments
 (0)