Skip to content

Commit 780bdfe

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
fix log in sinkhorn div and add log tests
1 parent 7c02007 commit 780bdfe

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

ot/bregman.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,8 +1569,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15691569
15701570
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
15711571
'''
1572+
if log:
1573+
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1574+
1575+
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1576+
1577+
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
15721578

1573-
sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
1574-
empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
1575-
empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs))
1576-
return max(0, sinkhorn_div)
1579+
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
1580+
1581+
log = {}
1582+
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
1583+
log['sinkhorn_loss_a'] = sinkhorn_loss_a
1584+
log['sinkhorn_loss_b'] = sinkhorn_loss_b
1585+
log['log_sinkhorn_ab'] = log_ab
1586+
log['log_sinkhorn_a'] = log_a
1587+
log['log_sinkhorn_b'] = log_b
1588+
1589+
return max(0, sinkhorn_div), log
1590+
else:
1591+
sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
1592+
1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
1593+
1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs))
1594+
return max(0, sinkhorn_div)

test/test_bregman.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def test_empirical_sinkhorn():
204204
G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
205205
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
206206

207+
G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
208+
sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
209+
207210
G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
208211
sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
209212

@@ -215,6 +218,10 @@ def test_empirical_sinkhorn():
215218
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
216219
np.testing.assert_allclose(
217220
sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian
221+
np.testing.assert_allclose(
222+
sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
223+
np.testing.assert_allclose(
224+
sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
218225
np.testing.assert_allclose(
219226
sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian
220227
np.testing.assert_allclose(
@@ -237,8 +244,11 @@ def test_empirical_sinkhorn_divergence():
237244
sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
238245
ot.sinkhorn2(b, b, M_t, 1))
239246

247+
emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True)
248+
sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
249+
240250
# check constratints
241251
np.testing.assert_allclose(
242252
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
243253
np.testing.assert_allclose(
244-
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
254+
emp_sinkhorn_div_log, sinkhorn_div_log, atol=1e-05) # cf conv emp sinkhorn

0 commit comments

Comments
 (0)