Skip to content

Commit 69186a6

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
fix test sinkhorn div
1 parent 780bdfe commit 69186a6

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

ot/bregman.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,8 +1587,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15871587
log['log_sinkhorn_b'] = log_b
15881588

15891589
return max(0, sinkhorn_div), log
1590+
15901591
else:
1591-
sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
1592-
1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
1593-
1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs))
1592+
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1593+
1594+
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1595+
1596+
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1597+
1598+
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
15941599
return max(0, sinkhorn_div)

test/test_bregman.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,13 @@ def test_empirical_sinkhorn_divergence():
241241
M_t = ot.dist(X_t, X_t)
242242

243243
emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
244-
sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
245-
ot.sinkhorn2(b, b, M_t, 1))
244+
sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
246245

247246
emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True)
248-
sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
247+
sink_div_log, log_s = ot.sinkhorn2(a, b, M, 1)
248+
sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1)
249+
sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1)
250+
sink_div_log = sink_div_log - 1 / 2 * (sink_div_log_a + sink_div_log_b)
249251

250252
# check constratints
251253
np.testing.assert_allclose(

0 commit comments

Comments
 (0)