Commit f2aaf40

debug sinkhorn divergence gradients

1 parent 0138dcf

2 files changed: 35 additions, 4 deletions

ot/bregman.py

Lines changed: 7 additions & 4 deletions
@@ -3173,8 +3173,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
             return loss

     else:
-        M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
-        M = nx.from_numpy(M, type_as=a)
+        M = dist(X_s, X_t, metric=metric)

         if log:
             sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
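This is the core of the fix in empirical_sinkhorn2: the old code converted the samples to NumPy to build the cost matrix and then wrapped the result back into the backend, which silently cut the autograd graph between the samples and the loss. Computing dist directly on the backend tensors keeps the whole chain differentiable. A minimal sketch of the effect, in plain PyTorch with the squared-Euclidean cost written out by hand (not POT's actual dist):

import torch

X_s = torch.randn(6, 2, requires_grad=True)
X_t = torch.randn(6, 2)

# Old behaviour: the round-trip through NumPy detaches from the graph
M_np = ((X_s.detach().numpy()[:, None, :] - X_t.numpy()[None, :, :]) ** 2).sum(-1)
M_old = torch.from_numpy(M_np)
print(M_old.grad_fn)   # None -> gradients cannot flow back to X_s

# New behaviour: stay on backend tensors, the graph is preserved
M_new = ((X_s[:, None, :] - X_t[None, :, :]) ** 2).sum(-1)
print(M_new.grad_fn)   # <SumBackward1 ...> -> differentiable w.r.t. X_s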
@@ -3287,6 +3286,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         International Conference on Artficial Intelligence and Statistics,
         (AISTATS) 21, 2018
     '''
+    X_s, X_t = list_to_array(X_s, X_t)
+
+    nx = get_backend(X_s, X_t)
+
     if log:
         sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
                                                        numIterMax=numIterMax,
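The two added lines make empirical_sinkhorn_divergence itself backend-aware: list_to_array normalizes list inputs to arrays, and get_backend picks the backend wrapper (NumPy, Torch, ...) matching the inputs, which the nx.maximum calls below rely on. A minimal sketch of the dispatch, assuming POT and PyTorch are installed:

import torch
from ot.backend import get_backend

x = torch.ones(3)
nx = get_backend(x)          # backend object matching the input type
print(type(nx).__name__)     # TorchBackend
print(nx.maximum(0, x - 2))  # tensor([0., 0., 0.]) -- stays a torch tensor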
@@ -3313,7 +3316,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         log['log_sinkhorn_a'] = log_a
         log['log_sinkhorn_b'] = log_b

-        return max(0, sinkhorn_div), log
+        return nx.maximum(0, sinkhorn_div), log

     else:
         sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
@@ -3332,7 +3335,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
                                               warn=warn, **kwargs)

         sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
-        return max(0, sinkhorn_div)
+        return nx.maximum(0, sinkhorn_div)


 def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
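The second half of the fix: Python's built-in max(0, sinkhorn_div) compares a scalar against the backend tensor, and when the divergence comes out slightly negative from round-off it returns the plain int 0, which is a constant with no backward method; nx.maximum keeps the clamp inside the graph either way. A small sketch of the failure mode, assuming PyTorch (torch.clamp stands in for the backend call and is equivalent here):

import torch

div = torch.tensor(-1e-12, requires_grad=True)  # tiny negative round-off

old = max(0, div)              # built-in max picks the Python int 0
print(type(old))               # <class 'int'> -> old.backward() is impossible

new = torch.clamp(div, min=0)  # equivalent to nx.maximum(0, div) on this backend
new.backward()
print(div.grad)                # tensor(0.) -> well-defined (zero) gradient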

test/test_bregman.py

Lines changed: 28 additions & 0 deletions
@@ -879,6 +879,34 @@ def test_empirical_sinkhorn_divergence(nx):
     ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)


+def test_empirical_sinkhorn_divergence_gradient():
+    # Test sinkhorn divergence
+    n = 10
+    a = np.linspace(1, n, n)
+    a /= a.sum()
+    b = ot.unif(n)
+    X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+    X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
+
+    nx = ot.backend.TorchBackend()
+
+    ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+    ab.requires_grad = True
+    bb.requires_grad = True
+    X_sb.requires_grad = True
+    X_tb.requires_grad = True
+
+    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)
+
+    emp_sinkhorn_div.backward()
+
+    assert ab.grad is not None
+    assert bb.grad is not None
+    assert X_sb.grad is not None
+    assert X_tb.grad is not None
+
+
 def test_stabilized_vs_sinkhorn_multidim(nx):
     # test if stable version matches sinkhorn
     # for multidimensional inputs
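With both changes in place, the new test exercises the full differentiable path: gradients flow from the divergence back to the weights and the sample positions. A short usage sketch of what this enables (assumes POT with PyTorch installed), taking one gradient-descent step on the source samples:

import torch
import ot

n = 10
X_s = torch.randn(n, 1, dtype=torch.float64, requires_grad=True)
X_t = 2 * torch.arange(n, dtype=torch.float64).reshape(n, 1)

div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
div.backward()

with torch.no_grad():      # one gradient-descent step on the samples
    X_s -= 0.1 * X_s.grad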
