Skip to content

Commit c9578b4

Browse files
tlacomberflamary
andauthored
[MRG] Fix#421 pass stopThr to the sinkhorn function in empirical_sinkhorn_divergence (#422)
* fix stopThr hardcoded in some places * added fix documentation in RELEASES.Md Co-authored-by: Rémi Flamary <[email protected]>
1 parent f827771 commit c9578b4

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ roughly 2^31) (PR #381)
3232
- Fixed weak optimal transport docstring (Issue #404, PR #410)
3333
- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
3434
PR #413)
35+
- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
36+
that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
3537
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
3638

3739

ot/bregman.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,7 @@ def get_reg(n): # exponential decreasing
12811281
regi = get_reg(ii)
12821282

12831283
G, logi = sinkhorn_stabilized(a, b, M, regi,
1284-
numItermax=numInnerItermax, stopThr=1e-9,
1284+
numItermax=numInnerItermax, stopThr=stopThr,
12851285
warmstart=(alpha, beta), verbose=False,
12861286
print_period=20, tau=tau, log=True)
12871287

@@ -3306,17 +3306,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33063306
if log:
33073307
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
33083308
numIterMax=numIterMax,
3309-
stopThr=1e-9, verbose=verbose,
3309+
stopThr=stopThr, verbose=verbose,
33103310
log=log, warn=warn, **kwargs)
33113311

33123312
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
33133313
numIterMax=numIterMax,
3314-
stopThr=1e-9, verbose=verbose,
3314+
stopThr=stopThr, verbose=verbose,
33153315
log=log, warn=warn, **kwargs)
33163316

33173317
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
33183318
numIterMax=numIterMax,
3319-
stopThr=1e-9, verbose=verbose,
3319+
stopThr=stopThr, verbose=verbose,
33203320
log=log, warn=warn, **kwargs)
33213321

33223322
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -3333,17 +3333,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33333333

33343334
else:
33353335
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
3336-
numIterMax=numIterMax, stopThr=1e-9,
3336+
numIterMax=numIterMax, stopThr=stopThr,
33373337
verbose=verbose, log=log,
33383338
warn=warn, **kwargs)
33393339

33403340
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
3341-
numIterMax=numIterMax, stopThr=1e-9,
3341+
numIterMax=numIterMax, stopThr=stopThr,
33423342
verbose=verbose, log=log,
33433343
warn=warn, **kwargs)
33443344

33453345
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
3346-
numIterMax=numIterMax, stopThr=1e-9,
3346+
numIterMax=numIterMax, stopThr=stopThr,
33473347
verbose=verbose, log=log,
33483348
warn=warn, **kwargs)
33493349

0 commit comments

Comments
 (0)