Skip to content

Commit 782d9b1

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
fix test sinkhorn div
1 parent 69186a6 commit 782d9b1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

test/test_bregman.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,11 @@ def test_empirical_sinkhorn_divergence():
243243
emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
244244
sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
245245

246-
emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True)
247-
sink_div_log, log_s = ot.sinkhorn2(a, b, M, 1)
248-
sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1)
249-
sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1)
250-
sink_div_log = sink_div_log - 1 / 2 * (sink_div_log_a + sink_div_log_b)
246+
emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
247+
sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
248+
sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
249+
sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
250+
sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
251251

252252
# check constratints
253253
np.testing.assert_allclose(

0 commit comments

Comments
 (0)