
Commit db28f4b

SoniaMaz8 and rflamary authored
[MRG] Sinkhorn gradient last step (#693)
* change solver
* test
* update test
* Update ot/solvers.py (Co-authored-by: Rémi Flamary <[email protected]>)
* update doc
* add test for max_iter
* fix bug on gradients
* update RELEASES.md
* update comment
* add detach and comment
* add example
* add test for detach
* fix example
* delete unused importations in example
* move example to backend
* reduce n_trials for example

Co-authored-by: Rémi Flamary <[email protected]>
1 parent 6311e25 commit db28f4b

File tree

4 files changed, +203 -7 lines changed


Diff for: RELEASES.md

+1
@@ -4,6 +4,7 @@
 
 #### New features
 - Implement CG solvers for partial FGW (PR #687)
+- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
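
A minimal usage sketch of the new `grad="last_step"` option (assuming PyTorch is available; the problem sizes and regularization below are illustrative, not taken from this commit):

import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
x, y = rng.randn(10, 2), rng.randn(7, 2)

a = torch.tensor(ot.utils.unif(10), requires_grad=True)
b = torch.tensor(ot.utils.unif(7), requires_grad=True)
M = torch.tensor(ot.dist(x, y), requires_grad=True)

# Differentiate only through the last Sinkhorn iteration (cheaper than full autodiff)
res = ot.solve(M, a, b, reg=1, grad="last_step")
res.value.backward()
print(M.grad.shape, a.grad.shape, b.grad.shape)

The `"envelope"` and `"detach"` options are called the same way; only the string passed to `grad` changes.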

Diff for: examples/backends/plot_Sinkhorn_gradients.py

+85
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
"""
================================================
Different gradient computations for regularized optimal transport
================================================

This example illustrates the differences in terms of computation time between the gradient options for the Sinkhorn solver.

"""

# Author: Sonia Mazelet <[email protected]>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import matplotlib.pylab as pl
import ot
from ot.backend import torch


##############################################################################
# Time comparison of the Sinkhorn solver for different gradient options
# -------------


# %% parameters

n_trials = 10
times_autodiff = torch.zeros(n_trials)
times_envelope = torch.zeros(n_trials)
times_last_step = torch.zeros(n_trials)

n_samples_s = 300
n_samples_t = 300
n_features = 5
reg = 0.03

# Time required for the Sinkhorn solver and gradient computations, for different gradient options, over multiple random problems
for i in range(n_trials):
    x = torch.rand((n_samples_s, n_features))
    y = torch.rand((n_samples_t, n_features))
    a = ot.utils.unif(n_samples_s)
    b = ot.utils.unif(n_samples_t)
    M = ot.dist(x, y)

    a = torch.tensor(a, requires_grad=True)
    b = torch.tensor(b, requires_grad=True)
    M = M.clone().detach().requires_grad_(True)

    # autodiff provides the gradient for all the outputs (plan, value, value_linear)
    ot.tic()
    res_autodiff = ot.solve(M, a, b, reg=reg, grad="autodiff")
    res_autodiff.value.backward()
    times_autodiff[i] = ot.toq()

    a = a.clone().detach().requires_grad_(True)
    b = b.clone().detach().requires_grad_(True)
    M = M.clone().detach().requires_grad_(True)

    # envelope provides the gradient for value
    ot.tic()
    res_envelope = ot.solve(M, a, b, reg=reg, grad="envelope")
    res_envelope.value.backward()
    times_envelope[i] = ot.toq()

    a = a.clone().detach().requires_grad_(True)
    b = b.clone().detach().requires_grad_(True)
    M = M.clone().detach().requires_grad_(True)

    # last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm
    ot.tic()
    res_last_step = ot.solve(M, a, b, reg=reg, grad="last_step")
    res_last_step.value.backward()
    times_last_step[i] = ot.toq()

pl.figure(1, figsize=(5, 3))
pl.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
pl.boxplot(
    ([times_autodiff, times_envelope, times_last_step]),
    tick_labels=["autodiff", "envelope", "last_step"],
    showfliers=False,
)
pl.ylabel("Time (s)")
pl.show()
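
As a quick follow-up to this example (not part of the committed file), the four gradient options can also be compared by checking which outputs still carry an autograd graph. The sketch below assumes the variables `a`, `b`, `M` and `reg` from the loop above are still in scope; it prints the status rather than asserting expected values.

# Illustrative check of which outputs remain differentiable per gradient option
for mode in ["autodiff", "envelope", "last_step", "detach"]:
    a_ = a.clone().detach().requires_grad_(True)
    b_ = b.clone().detach().requires_grad_(True)
    M_ = M.clone().detach().requires_grad_(True)
    res = ot.solve(M_, a_, b_, reg=reg, grad=mode)
    print(
        mode,
        "value differentiable:", res.value.grad_fn is not None,
        "plan differentiable:", res.plan.grad_fn is not None,
    )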

Diff for: ot/solvers.py

+31 -6
@@ -125,11 +125,13 @@ def solve(
     verbose : bool, optional
         Print information in the solver, by default False
     grad : str, optional
-        Type of gradient computation, either or 'autodiff' or 'envelope' used only for
+        Type of gradient computation, either 'autodiff', 'envelope' or 'last_step', used only for
         Sinkhorn solver. By default 'autodiff' provides gradients wrt all
         outputs (`plan, value, value_linear`) but with important memory cost.
         'envelope' provides gradients only for `value` and and other outputs are
-        detached. This is useful for memory saving when only the value is needed.
+        detached. This is useful for memory saving when only the value is needed. 'last_step' provides
+        gradients computed only through the last iteration of the Sinkhorn solver, but for both the OT plan and the objective values.
+        'detach' does not compute the gradients for the Sinkhorn solver.
 
     Returns
     -------
@@ -281,7 +283,6 @@ def solve(
         linear regression. NeurIPS.
 
     """
-
    # detect backend
    nx = get_backend(M, a, b, c)
 
@@ -412,7 +413,11 @@
                potentials = (log["u"], log["v"])
 
            elif reg_type.lower() in ["entropy", "kl"]:
-                if grad == "envelope":  # if envelope then detach the input
+                if grad in [
+                    "envelope",
+                    "last_step",
+                    "detach",
+                ]:  # if envelope, last_step or detach then detach the input
                    M0, a0, b0 = M, a, b
                    M, a, b = nx.detach(M, a, b)
 
@@ -421,6 +426,12 @@
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
+                if grad == "last_step":
+                    if max_iter == 0:
+                        raise ValueError(
+                            "The maximum number of iterations must be greater than 0 when using grad=last_step."
+                        )
+                    max_iter = max_iter - 1
 
                plan, log = sinkhorn_log(
                    a,
@@ -433,6 +444,22 @@
                    verbose=verbose,
                )
 
+                potentials = (log["log_u"], log["log_v"])
+
+                # if last_step, compute the last step of the Sinkhorn algorithm with the non-detached inputs
+                if grad == "last_step":
+                    loga = nx.log(a0)
+                    logb = nx.log(b0)
+                    v = logb - nx.logsumexp(-M0 / reg + potentials[0][:, None], 0)
+                    u = loga - nx.logsumexp(-M0 / reg + potentials[1][None, :], 1)
+                    plan = nx.exp(-M0 / reg + u[:, None] + v[None, :])
+                    potentials = (u, v)
+                    log["niter"] = max_iter + 1
+                    log["log_u"] = u
+                    log["log_v"] = v
+                    log["u"] = nx.exp(u)
+                    log["v"] = nx.exp(v)
+
                value_linear = nx.sum(M * plan)
 
                if reg_type.lower() == "entropy":
@@ -442,8 +469,6 @@
                        plan, a[:, None] * b[None, :]
                    )
 
-                potentials = (log["log_u"], log["log_v"])
-
                if grad == "envelope":  # set the gradient at convergence
                    value = nx.set_gradients(
                        value,
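
The heart of the `last_step` branch above is a single extra log-domain Sinkhorn update performed on the non-detached inputs, reusing the detached potentials returned by `sinkhorn_log`. A standalone sketch in plain PyTorch (the function name and signature are illustrative, not part of the library):

import torch

def resolve_last_step(M, a, b, log_u, log_v, reg):
    # One extra log-domain Sinkhorn update on the differentiable inputs M, a, b,
    # reusing the (detached) potentials log_u, log_v from the converged solver
    v = torch.log(b) - torch.logsumexp(-M / reg + log_u[:, None], dim=0)
    u = torch.log(a) - torch.logsumexp(-M / reg + log_v[None, :], dim=1)
    plan = torch.exp(-M / reg + u[:, None] + v[None, :])
    return plan, u, v

Because only this one update is recorded by autograd, backpropagation goes through a single logsumexp per marginal instead of the full iteration history, which is where the memory saving over plain autodiff comes from.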

Diff for: test/test_solvers.py

+86 -1
@@ -143,6 +143,91 @@ def test_solve(nx):
     sol0 = ot.solve(M, reg=1, reg_type="cryptic divergence")
 
 
+@pytest.mark.skipif(not torch, reason="torch no installed")
+def test_solve_last_step():
+    n_samples_s = 10
+    n_samples_t = 7
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples_s, n_features)
+    y = rng.randn(n_samples_t, n_features)
+    a = ot.utils.unif(n_samples_s)
+    b = ot.utils.unif(n_samples_t)
+    M = ot.dist(x, y)
+
+    # Check that last_step and autodiff give the same result and similar gradients
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol0 = ot.solve(M, a, b, reg=10, grad="autodiff")
+    sol0.value.backward()
+
+    gM0 = M.grad.clone()
+    ga0 = a.grad.clone()
+    gb0 = b.grad.clone()
+
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol = ot.solve(M, a, b, reg=10, grad="last_step")
+    sol.value.backward()
+
+    gM = M.grad.clone()
+    ga = a.grad.clone()
+    gb = b.grad.clone()
+
+    # Note, gradients are invariant to change in constant so we center them
+    cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
+    tolerance = 0.96
+    assert cos(gM0.flatten(), gM.flatten()) > tolerance
+    assert cos(ga0 - ga0.mean(), ga - ga.mean()) > tolerance
+    assert cos(gb0 - gb0.mean(), gb - gb.mean()) > tolerance
+
+    assert torch.allclose(sol0.plan, sol.plan)
+    assert torch.allclose(sol0.value, sol.value)
+    assert torch.allclose(sol0.value_linear, sol.value_linear)
+    assert torch.allclose(sol0.potentials[0], sol.potentials[0])
+    assert torch.allclose(sol0.potentials[1], sol.potentials[1])
+
+    with pytest.raises(ValueError):
+        ot.solve(M, a, b, grad="last_step", max_iter=0, reg=10)
+
+
+@pytest.mark.skipif(not torch, reason="torch no installed")
+def test_solve_detach():
+    n_samples_s = 10
+    n_samples_t = 7
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples_s, n_features)
+    y = rng.randn(n_samples_t, n_features)
+    a = ot.utils.unif(n_samples_s)
+    b = ot.utils.unif(n_samples_t)
+    M = ot.dist(x, y)
+
+    # Check that detach and autodiff give the same result, and that detach carries no gradients
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol0 = ot.solve(M, a, b, reg=10, grad="detach")
+
+    with pytest.raises(RuntimeError):
+        sol0.value.backward()
+
+    sol = ot.solve(M, a, b, reg=10, grad="autodiff")
+
+    assert torch.allclose(sol0.plan, sol.plan)
+    assert torch.allclose(sol0.value, sol.value)
+    assert torch.allclose(sol0.value_linear, sol.value_linear)
+    assert torch.allclose(sol0.potentials[0], sol.potentials[0])
+    assert torch.allclose(sol0.potentials[1], sol.potentials[1])
+
+
 @pytest.mark.skipif(not torch, reason="torch no installed")
 def test_solve_envelope():
     n_samples_s = 10
@@ -178,7 +263,7 @@ def test_solve_envelope():
     ga = a.grad.clone()
     gb = b.grad.clone()
 
-    # Note, gradients aer invariant to change in constant so we center them
+    # Note, gradients are invariant to change in constant so we center them
     assert torch.allclose(gM0, gM)
     assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
     assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean())
