Skip to content

Commit 058d275

Browse files
[MRG] Fix warning bug in sinkhorn2 (#417)
* Pass warn argument downstream in sinkhorn2 method. * releases.md * Fix unittest. Co-authored-by: Rémi Flamary <[email protected]>
1 parent c9578b4 commit 058d275

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ roughly 2^31) (PR #381)
3030
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
3131
- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
3232
- Fixed weak optimal transport docstring (Issue #404, PR #410)
33-
- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
33+
- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
3434
PR #413)
35+
- Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417)
3536
- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
3637
that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
3738
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)

ot/bregman.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -323,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
323323
if len(b.shape) < 2:
324324
if method.lower() == 'sinkhorn':
325325
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
326-
stopThr=stopThr, verbose=verbose, log=log,
326+
stopThr=stopThr, verbose=verbose,
327+
log=log, warn=warn,
327328
**kwargs)
328329
elif method.lower() == 'sinkhorn_log':
329330
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
330-
stopThr=stopThr, verbose=verbose, log=log,
331+
stopThr=stopThr, verbose=verbose,
332+
log=log, warn=warn,
331333
**kwargs)
332334
elif method.lower() == 'sinkhorn_stabilized':
333335
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
334-
stopThr=stopThr, verbose=verbose, log=log,
336+
stopThr=stopThr, verbose=verbose,
337+
log=log, warn=warn,
335338
**kwargs)
336339
else:
337340
raise ValueError("Unknown method '%s'." % method)
@@ -344,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
344347

345348
if method.lower() == 'sinkhorn':
346349
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
347-
stopThr=stopThr, verbose=verbose, log=log,
350+
stopThr=stopThr, verbose=verbose,
351+
log=log, warn=warn,
348352
**kwargs)
349353
elif method.lower() == 'sinkhorn_log':
350354
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
351-
stopThr=stopThr, verbose=verbose, log=log,
355+
stopThr=stopThr, verbose=verbose,
356+
log=log, warn=warn,
352357
**kwargs)
353358
elif method.lower() == 'sinkhorn_stabilized':
354359
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
355-
stopThr=stopThr, verbose=verbose, log=log,
360+
stopThr=stopThr, verbose=verbose,
361+
log=log, warn=warn,
356362
**kwargs)
357363
else:
358364
raise ValueError("Unknown method '%s'." % method)

test/test_bregman.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#
88
# License: MIT License
99

10+
import warnings
1011
from itertools import product
1112

1213
import numpy as np
@@ -58,7 +59,10 @@ def test_convergence_warning(method):
5859
with pytest.warns(UserWarning):
5960
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
6061
with pytest.warns(UserWarning):
61-
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
62+
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True)
63+
with warnings.catch_warnings():
64+
warnings.simplefilter("error")
65+
ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False)
6266

6367

6468
def test_not_implemented_method():

0 commit comments

Comments
 (0)